
Say, for example, I have some classes which all inherit from the same parent class and take the same parameters. A common example:

class Pet:
    def __init__(self, name, colour):
        ...  # shared setup elided

class Cat(Pet):
    def __init__(self, name, colour):
        Pet.__init__(self, name, colour)
        ...

class Cactus(Pet):
    def __init__(self, name, colour):
        Pet.__init__(self, name, colour)
        ...

And then say I want to instantiate some type of pet later in the program, based on user input. My first thought would be:

if pet_type == 'Cat':
    animal = Cat(name, colour)
elif pet_type == 'Cactus':
    animal = Cactus(name, colour)
# ... and so on for every other pet type

But is there a better way that doesn't require an if? For example, if the program grew to include over 1000 animals, all descending from Pet, a chain of ifs would not be feasible.

  • Possible duplicate of [Too many if statements](https://stackoverflow.com/questions/31748617/too-many-if-statements) – Aran-Fey Oct 05 '17 at 14:20

3 Answers


Create a dictionary of allowable classes:

classes = {
    'Cat': Cat,
    'Cactus': Cactus,
}

try:
    cls = classes[pet_type]
except KeyError:
    # handle invalid pet_type here
    ...
else:
    animal = cls(name, colour)

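Depending on how you respond to a KeyError (for example, if the except clause assigns a fallback class to cls rather than re-raising), you may want to use a finally clause instead of an else clause, or simply use the dict's get method with a default class, e.g.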

animal = classes.get(pet_type, Pet)(name, colour)
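With 1000 classes, typing the dictionary by hand is its own chore. A minimal sketch, assuming every pet class is a direct subclass of Pet, builds the same mapping from `Pet.__subclasses__()` and then does the same dictionary lookup. The helper name `make_pet` is hypothetical. Note that `__subclasses__()` only lists direct subclasses, and only those whose class statements have already run.

```python
class Pet:
    def __init__(self, name, colour):
        self.name = name
        self.colour = colour


class Cat(Pet):
    pass


class Cactus(Pet):
    pass


def make_pet(pet_type, name, colour):
    """Hypothetical factory: look up pet_type among Pet's direct subclasses."""
    # Rebuilt on each call so subclasses defined later are also picked up
    classes = {cls.__name__: cls for cls in Pet.__subclasses__()}
    try:
        cls = classes[pet_type]
    except KeyError:
        raise ValueError(f"unknown pet type: {pet_type!r}") from None
    return cls(name, colour)


animal = make_pet("Cactus", "Spike", "green")
print(type(animal).__name__, animal.name, animal.colour)  # Cactus Spike green
```

Rebuilding the dict on every call is simple but wasteful; caching it once after all classes are defined is an easy refinement.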
chepner

What you are looking for is known as the factory pattern. There are a great many ways of achieving it, ranging from explicit if-cascades like the one you show to metaclass magic.

A rather straightforward way is a class decorator that registers each class in a lookup table:

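PET_CLASSES = {}


def pet_class(cls):
    # Register the class under its lower-cased name
    PET_CLASSES[cls.__name__.lower()] = cls
    return cls  # return the class, otherwise the decorated name is bound to None


def create_pet(name):
    # Look up the registered class by name and instantiate it
    return PET_CLASSES[name.lower()]()


@pet_class
class Cat:
    pass


@pet_class
class Dog:
    pass


print(create_pet("dog"))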
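A related sketch swaps the explicit decorator for the `__init_subclass__` hook (Python 3.6+), which runs automatically whenever a subclass of `Pet` is defined, so no class can be left out of the registry by forgetting a decorator. The names `Pet._registry` and `Pet.create` are hypothetical; unlike `__subclasses__()`, this also registers indirect subclasses.

```python
class Pet:
    # Hypothetical registry: lower-cased class name -> class
    _registry = {}

    def __init_subclass__(cls, **kwargs):
        # Called once for every subclass definition, direct or indirect
        super().__init_subclass__(**kwargs)
        Pet._registry[cls.__name__.lower()] = cls

    def __init__(self, name, colour):
        self.name = name
        self.colour = colour

    @classmethod
    def create(cls, pet_type, name, colour):
        # Factory: raises KeyError for an unknown pet_type
        return Pet._registry[pet_type.lower()](name, colour)


class Cat(Pet):
    pass


class Dog(Pet):
    pass


class Kitten(Cat):  # indirect subclass, registered as well
    pass


pet = Pet.create("Dog", "Rex", "brown")
print(type(pet).__name__, pet.name)  # Dog Rex
```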
deets

You can get a class from its name (as a string) with getattr():

my_class = getattr(module, "class_name")

and then

my_object = my_class(...)

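The module object can be obtained with module = __import__("module_name"), although importlib.import_module("module_name") is the documented, preferred way to import a module by name.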
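When the classes live in the same module that does the lookup, `sys.modules[__name__]` gives that module object without importing anything. Because `getattr` will return any attribute the user names (a function, an imported module, ...), a minimal sketch, assuming all valid classes subclass `Pet`, checks the result before calling it. The helper `pet_from_name` is hypothetical.

```python
import sys


class Pet:
    def __init__(self, name, colour):
        self.name = name
        self.colour = colour


class Cat(Pet):
    pass


def pet_from_name(class_name, name, colour):
    """Hypothetical helper: resolve class_name in this module and instantiate it."""
    module = sys.modules[__name__]  # the current module object
    cls = getattr(module, class_name, None)  # None if there is no such attribute
    # Reject anything that is not a Pet subclass (e.g. "sys" or "pet_from_name")
    if not (isinstance(cls, type) and issubclass(cls, Pet)):
        raise ValueError(f"unknown pet type: {class_name!r}")
    return cls(name, colour)


cat = pet_from_name("Cat", "Tom", "grey")
print(type(cat).__name__, cat.colour)  # Cat grey
```

The subclass check matters because the string comes from user input; without it, any callable attribute of the module could be invoked.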

blue_note