1

I have lambda function that is screwing up the pickling of an object. What makes it hard to debug is that it doesn't tell me the name of field causing this issue. I created a recursive function that tries to find such fields but it fails on the code I need it to work on (but succeeds in toy self contained cases).

It works in this toy example:

# %%
"""
trying to detect which field is the anonymous function giving me isse since:
    AttributeError: Can't pickle local object 'FullOmniglot.__init__.<locals>.<lambda>'
doesn't tell me which one for some reason.
"""
import re
from typing import Any, Callable, Union, Optional


def _is_anonymous_function(f) -> bool:
    """
    Returns true if it's an anonynouys function.

    ref: https://stackoverflow.com/questions/3655842/how-can-i-test-whether-a-variable-holds-a-lambda
    """
    return callable(f) and f.__name__ == "<lambda>"


def _get_anonymous_function_attributes(anything, halt: bool = False, verbose: bool = False) -> dict:
    """
    Returns the dictionary of name of fields to anonymous functions in the past anything thing.

    :param anything:
    :param halt:
    :param verbose:
    :return:
    """
    anons: dict = {}
    for field_name in dir(anything):
        field = getattr(anything, field_name)
        if _is_anonymous_function(field):
            if verbose:
                print(f'{field_name=}')
                print(f'{field=}')
            if halt:
                from pdb import set_trace as st
                st()
            anons[str(field_name)] = field
    return anons

def _get_anonymous_function_attributes_recursive(anything: Any, path: str = '') -> dict[str, Callable]:
    """"""
    anons: dict = {}
    def __get_anonymous_function_attributes_recursive(anything: Any,
                                                      path: Optional[str] = '',
                                                      ) -> None:
        if _is_anonymous_function(anything):
            # assert field is anything, f'Err not save thing/obj: \n{field=}\n{anything=}'
            # key: str = str(dict(obj=anything, field_name=field_name))
            key: str = str(path)
            anons[key] = anything
        else:
            for field_name in dir(anything):
                # if field_name != '__abstractmethods__':
                if not bool(re.search(r'__(.+)__', field_name)):
                    field = getattr(anything, field_name)
                    # only recurse if new field is not itself
                    if field is not anything:  # avoids infinite recursions
                        path_for_this_field = f'{path}.{field_name}'
                        __get_anonymous_function_attributes_recursive(field, path_for_this_field)
        return
    __get_anonymous_function_attributes_recursive(anything, path)
    return anons

class MyObj:
    def __init__(self):
        self.data = 'hi'
        self.anon = lambda x: x
        local_variable_me = 'my a local variable!'

    def non_anon(self, x):
        return x

class MyObj2:
    def __init__(self):
        self.data = 'hi'
        self.anon = lambda x: x
        local_variable_me = 'my a local variable!'

        self.obj = MyObj()

    def non_anon(self, x):
        return x


"""
Trying to fix: AttributeError: Can't pickle local object 'FullOmniglot.__init__.<locals>.<lambda>'
Trying to approximate with my obj and get: obj.__init__.<locals> to to get the obj.__ini__.<locals>.<lambda> 
"""
top_obj = MyObj2()
# print(f'anons recursive: {_get_anonymous_function_attributes_recursive(obj)=}')
print('getting all anonymous functions recursively: ')
anons: dict = _get_anonymous_function_attributes_recursive(top_obj, 'top_obj')
print(f'{len(anons.keys())=}')
for k, v in anons.items():
    print()
    print(f'{k=}')
    print(f'{v=}')
    # print(k, v)
print()

but fails in the wild pytorch code:

# %%
"""
pip install torch
pip install learn2learn
"""
print()
import learn2learn
from torch.utils.data import DataLoader

omni = learn2learn.vision.benchmarks.get_tasksets('omniglot', root='~/data/l2l_data')
loader = DataLoader(omni, num_workers=1)
next(iter(loader))
print()

with error:

Traceback (most recent call last):
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/Users/brandomiranda/opt/anaconda3/envs/meta_learning/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'FullOmniglot.__init__.<locals>.<lambda>'

Why does it fail there?


Full self contained reproducible code in one place:

# %%
"""
trying to detect which field is the anonymous function giving me isse since:
    AttributeError: Can't pickle local object 'FullOmniglot.__init__.<locals>.<lambda>'
doesn't tell me which one for some reason.
"""
import re
from typing import Any, Callable, Union, Optional


def _is_anonymous_function(f) -> bool:
    """
    Returns true if it's an anonynouys function.

    ref: https://stackoverflow.com/questions/3655842/how-can-i-test-whether-a-variable-holds-a-lambda
    """
    return callable(f) and f.__name__ == "<lambda>"


def _get_anonymous_function_attributes(anything, halt: bool = False, verbose: bool = False) -> dict:
    """
    Returns the dictionary of name of fields to anonymous functions in the past anything thing.

    :param anything:
    :param halt:
    :param verbose:
    :return:
    """
    anons: dict = {}
    for field_name in dir(anything):
        field = getattr(anything, field_name)
        if _is_anonymous_function(field):
            if verbose:
                print(f'{field_name=}')
                print(f'{field=}')
            if halt:
                from pdb import set_trace as st
                st()
            anons[str(field_name)] = field
    return anons

def _get_anonymous_function_attributes_recursive(anything: Any, path: str = '') -> dict[str, Callable]:
    """"""
    anons: dict = {}
    def __get_anonymous_function_attributes_recursive(anything: Any,
                                                      path: Optional[str] = '',
                                                      ) -> None:
        if _is_anonymous_function(anything):
            # assert field is anything, f'Err not save thing/obj: \n{field=}\n{anything=}'
            # key: str = str(dict(obj=anything, field_name=field_name))
            key: str = str(path)
            anons[key] = anything
        else:
            for field_name in dir(anything):
                # if field_name != '__abstractmethods__':
                if not bool(re.search(r'__(.+)__', field_name)):
                    field = getattr(anything, field_name)
                    # only recurse if new field is not itself
                    if field is not anything:  # avoids infinite recursions
                        path_for_this_field = f'{path}.{field_name}'
                        __get_anonymous_function_attributes_recursive(field, path_for_this_field)
        return
    __get_anonymous_function_attributes_recursive(anything, path)
    return anons

class MyObj:
    def __init__(self):
        self.data = 'hi'
        self.anon = lambda x: x
        local_variable_me = 'my a local variable!'

    def non_anon(self, x):
        return x

class MyObj2:
    def __init__(self):
        self.data = 'hi'
        self.anon = lambda x: x
        local_variable_me = 'my a local variable!'

        self.obj = MyObj()

    def non_anon(self, x):
        return x


"""
Trying to fix: AttributeError: Can't pickle local object 'FullOmniglot.__init__.<locals>.<lambda>'
Trying to approximate with my obj and get: obj.__init__.<locals> to to get the obj.__ini__.<locals>.<lambda> 
"""
top_obj = MyObj2()
# print(f'anons recursive: {_get_anonymous_function_attributes_recursive(obj)=}')
print('getting all anonymous functions recursively: ')
anons: dict = _get_anonymous_function_attributes_recursive(top_obj, 'top_obj')
print(f'{len(anons.keys())=}')
for k, v in anons.items():
    print()
    print(f'{k=}')
    print(f'{v=}')
    # print(k, v)
print()

# from uutils import get_anonymous_function_attributes_recursive
# get_anonymous_function_attributes_recursive(top_obj, 'top_obj', print_output=True)
# print()
# %%
"""
pip install torch
pip install learn2learn
"""
print()
import learn2learn
from torch.utils.data import DataLoader

omni = learn2learn.vision.benchmarks.get_tasksets('omniglot', root='~/data/l2l_data')
loader = DataLoader(omni, num_workers=1)
next(iter(loader))
print()

related:

Charlie Parker
  • 5,884
  • 57
  • 198
  • 323

2 Answers2

0

Looking closer to your dataset

print()
import learn2learn
from torch.utils.data import DataLoader

omni = learn2learn.vision.benchmarks.get_tasksets('omniglot', root='~/data/l2l_data')

print(omni)
BenchmarkTasksets(
  train=<learn2learn.data.task_dataset.TaskDataset object at 0x7f07cd9f8830>,
  validation=<learn2learn.data.task_dataset.TaskDataset object at 0x7f07cd9f89f0>, 
  test=<learn2learn.data.task_dataset.TaskDataset object at 0x7f07cd9f8ad0>
)

I guess you want to load from one of the partitions of the dataset, train, validataion, or test

Here one example that creates a dataloader and show the shape of the tensors of one the first sample for each of the partitions of omni

for d in [omni.train, omni.validation, omni.test]:
  dl = DataLoader(d)
  print([tuple(t.shape) for t in next(iter(dl))])
[(1, 50, 1, 28, 28), (1, 50)]
[(1, 50, 1, 28, 28), (1, 50)]
[(1, 50, 1, 28, 28), (1, 50)]
Bob
  • 13,867
  • 1
  • 5
  • 27
0

Steps

  1. Fix recursion error

In __get_anonymous_function_attributes_recursive, add these two lines:

if _is_anonymous_function(anything):
    ...
elif type(anything).__module__ == 'builtins' or type(anything).__name__ == '_memoryviewslice':  # Add these
    pass                                                                                        # two lines
else:
  1. Fix callable has no attribute '__name__'

In _is_anonymous_function, change the following line:

# return callable(f) and f.__name__ == "<lambda>"                  # Change this
return callable(f) and getattr(f, '__name__', None) == "<lambda>"  # to this
  1. Support check lambda in list

In __get_anonymous_function_attributes_recursive, add these five lines:

path_for_this_field = f'{path}.{field_name}'
__get_anonymous_function_attributes_recursive(field, path_for_this_field)
if isinstance(field, list):                                                        # Add
    fields = field                                                                 # these
    for i, field in enumerate(fields):                                             # five
        path_for_this_field = f'{path}.{field_name}[{i}]'                          # lines
        __get_anonymous_function_attributes_recursive(field, path_for_this_field)  #

Result

{'omni.test.dataset.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.dataset.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.dataset.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.test.task_transforms[0].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[0].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[0].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.test.task_transforms[0].filter.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[0].filter.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[0].filter.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.test.task_transforms[0].kshots.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[0].kshots.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[0].kshots.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.test.task_transforms[0].nways.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[0].nways.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[0].nways.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.test.task_transforms[1].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[1].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[1].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.test.task_transforms[2].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[2].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[2].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.test.task_transforms[3].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[3].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[3].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.test.task_transforms[4].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[4].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.test.task_transforms[4].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.train.dataset.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.dataset.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.dataset.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.train.task_transforms[0].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[0].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[0].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.train.task_transforms[0].filter.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[0].filter.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[0].filter.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.train.task_transforms[0].kshots.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[0].kshots.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[0].kshots.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.train.task_transforms[0].nways.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[0].nways.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[0].nways.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.train.task_transforms[1].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[1].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[1].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.train.task_transforms[2].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[2].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[2].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.train.task_transforms[3].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[3].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[3].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.train.task_transforms[4].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[4].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.train.task_transforms[4].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.validation.dataset.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.dataset.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.dataset.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.validation.task_transforms[0].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[0].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[0].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.validation.task_transforms[0].filter.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[0].filter.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[0].filter.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.validation.task_transforms[0].kshots.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[0].kshots.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[0].kshots.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.validation.task_transforms[0].nways.dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[0].nways.dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[0].nways.dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.validation.task_transforms[1].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[1].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[1].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.validation.task_transforms[2].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[2].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[2].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.validation.task_transforms[3].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[3].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[3].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>,
 'omni.validation.task_transforms[4].dataset.dataset.dataset.datasets[1].target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[4].dataset.dataset.dataset.datasets[1].transforms.target_transform': <function FullOmniglot.__init__.<locals>.<lambda> at 0x111111111>,
 'omni.validation.task_transforms[4].dataset.dataset.transform.transforms[2]': <function omniglot_tasksets.<locals>.<lambda> at 0x222222222>}

Possible hack

class FullOmniglotTransform:
    """
    learn2learn.vision.datasets.full_omniglot.FullOmniglot.__init__
        target_transform=lambda x: x + len(omni_background._characters))
    """
    def __init__(self, omni_background):
        self.omni_background = omni_background

    def __call__(self, x):
        return x + len(self.omni_background._characters)


class OmniglotTasksetsTransform:
    """
    learn2learn.vision.benchmarks.omniglot_benchmark.omniglot_tasksets
        lambda x: 1.0 - x,
    """
    def __init__(self, v):
        self.v = v

    def __call__(self, x):
        return self.v - x

Add these three lines:

omni = learn2learn.vision.benchmarks.get_tasksets('omniglot', root='~/data/l2l_data')
omni_background, omni_evaluation = omni.test.dataset.dataset.dataset.dataset.datasets                                    # Add these
omni_evaluation.target_transform = omni_evaluation.transforms.target_transform = FullOmniglotTransform(omni_background)  # three
omni.test.dataset.dataset.dataset.transform.transforms[2] = OmniglotTasksetsTransform(1.0)                               # lines

The above works with num_workers=1, for the proper usage of BenchmarkTasksets ("omni") with DataLoader from Bob's answer:

for d in [omni.train, omni.validation, omni.test]:
    # dl = DataLoader(d)
    dl = DataLoader(d, num_workers=1)
    print([tuple(t.shape) for t in next(iter(dl))])
aaron
  • 39,695
  • 6
  • 46
  • 102