Use post_generation hooks to populate models with some data

Factory Boy provides us with a better way of populating complicated
fields (like M2M relations) than overriding _prepare class method,
so we should be using them.
This commit is contained in:
Krzysztof Klimonda
2013-03-04 14:39:26 -08:00
parent ef818b7f82
commit 7a85f1b514
2 changed files with 44 additions and 15 deletions

View File

@@ -1,3 +1,4 @@
from .api import * from .api import *
from .forms import * from .forms import *
from .helpers import PinFactoryTest
from .views import * from .views import *

View File

@@ -1,6 +1,8 @@
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import Permission from django.contrib.auth.models import Permission
from django.core.files import File from django.core.files import File
from django.db.models.query import QuerySet
from django.test import TestCase
import factory import factory
from taggit.models import Tag from taggit.models import Tag
@@ -18,17 +20,14 @@ class UserFactory(factory.Factory):
username = factory.Sequence(lambda n: 'user_{}'.format(n)) username = factory.Sequence(lambda n: 'user_{}'.format(n))
email = factory.Sequence(lambda n: 'user_{}@example.com'.format(n)) email = factory.Sequence(lambda n: 'user_{}@example.com'.format(n))
@classmethod @factory.post_generation(extract_prefix='password')
def _prepare(cls, create, **kwargs): def set_password(self, create, extracted, **kwargs):
password = kwargs.pop('password', None) self.set_password(extracted)
user = super(UserFactory, cls)._prepare(create, **kwargs) self.save()
user.user_permissions = Permission.objects.filter(codename__in=['add_pin', 'add_image'])
if password:
user.set_password(password)
if create:
user.save()
return user
@factory.post_generation(extract_prefix='user_permissions')
def set_user_permissions(self, create, extracted, **kwargs):
self.user_permissions = Permission.objects.filter(codename__in=['add_pin', 'add_image'])
class TagFactory(factory.Factory): class TagFactory(factory.Factory):
FACTORY_FOR = Tag FACTORY_FOR = Tag
@@ -48,8 +47,37 @@ class PinFactory(factory.Factory):
submitter = factory.SubFactory(UserFactory) submitter = factory.SubFactory(UserFactory)
image = factory.SubFactory(ImageFactory) image = factory.SubFactory(ImageFactory)
@classmethod @factory.post_generation(extract_prefix='tags')
def _prepare(cls, create, **kwargs): def add_tags(self, create, extracted, **kwargs):
pin = super(PinFactory, cls)._prepare(create, **kwargs) if isinstance(extracted, Tag):
pin.tags.add(TagFactory()) self.tags.add(extracted)
return pin elif isinstance(extracted, list):
self.tags.add(*extracted)
elif isinstance(extracted, QuerySet):
self.tags = extracted
else:
self.tags.add(TagFactory())
class PinFactoryTest(TestCase):
def test_default_tags(self):
self.assertTrue(PinFactory().tags.get(pk=1).name.startswith('tag_'))
def test_custom_tag(self):
custom = 'custom_tag'
self.assertEqual(PinFactory(tags=Tag.objects.create(name=custom)).tags.get(pk=1).name, custom)
def test_custom_tags_list(self):
tags = TagFactory.create_batch(2)
PinFactory(tags=tags)
self.assertEqual(Tag.objects.count(), 2)
def test_custom_tags_queryset(self):
TagFactory.create_batch(2)
tags = Tag.objects.all()
PinFactory(tags=tags)
self.assertEqual(Tag.objects.count(), 2)
def test_empty_tags(self):
PinFactory(tags=[])
self.assertEqual(Tag.objects.count(), 0)