Files
Pinry/pinry/core/tests/helpers.py
Krzysztof Klimonda 7a85f1b514 Use post_generation hooks to populate models with some data
Factory Boy provides us with a better way of populating complicated
fields (like M2M relations) than overriding the _prepare class method,
so we should use them.
2013-03-04 14:58:54 -08:00

83 lines
2.5 KiB
Python

import os

from django.conf import settings
from django.contrib.auth.models import Permission
from django.core.files import File
from django.db.models.query import QuerySet
from django.test import TestCase
import factory
from taggit.models import Tag

from ..models import Image, Pin
from ...users.models import User
# Path of the 'screenshot.png' fixture that ImageFactory opens. os.path.join
# is used instead of string concatenation so the path is correct whether or
# not settings.SITE_ROOT ends with a path separator.
TEST_IMAGE_PATH = os.path.join(settings.SITE_ROOT, 'screenshot.png')
class UserFactory(factory.Factory):
    """Factory for ``User`` objects with sequential usernames and e-mails.

    Keyword arguments routed to the post-generation hooks:

    * ``password`` -- raw password handed to ``User.set_password``.
    * ``user_permissions`` -- permissions to assign; when omitted, the user
      receives the ``add_pin`` and ``add_image`` permissions.
    """
    FACTORY_FOR = User

    username = factory.Sequence(lambda n: 'user_{}'.format(n))
    email = factory.Sequence(lambda n: 'user_{}@example.com'.format(n))

    @factory.post_generation(extract_prefix='password')
    def set_password(self, create, extracted, **kwargs):
        """Hash ``extracted`` into the generated user's password.

        ``self`` here is the generated ``User`` instance, not the factory.
        NOTE(review): when no ``password`` keyword is given, ``extracted`` is
        ``None`` and is passed straight to ``User.set_password`` -- confirm
        how the installed Django version treats ``None``.
        """
        self.set_password(extracted)
        # Only persist under the create strategy; build() must not write to
        # the database.
        if create:
            self.save()

    @factory.post_generation(extract_prefix='user_permissions')
    def set_user_permissions(self, create, extracted, **kwargs):
        """Assign permissions to the generated user (create strategy only).

        Uses ``extracted`` when the caller passes ``user_permissions``;
        otherwise grants the ``add_pin`` and ``add_image`` permissions.
        """
        # A many-to-many relation cannot be written on an unsaved instance,
        # so there is nothing to do for build().
        if not create:
            return
        if extracted is None:
            extracted = Permission.objects.filter(
                codename__in=['add_pin', 'add_image'])
        self.user_permissions = extracted
class TagFactory(factory.Factory):
    """Factory for taggit ``Tag`` objects named ``tag_<sequence number>``."""
    FACTORY_FOR = Tag

    # '%s' renders the sequence value the same way whether factory_boy
    # supplies it as an int or a str.
    name = factory.Sequence(lambda seq: 'tag_%s' % (seq,))
class ImageFactory(factory.Factory):
    """Factory for ``Image`` objects whose file is the ``TEST_IMAGE_PATH`` fixture."""
    FACTORY_FOR = Image

    # Open in binary mode: the fixture is image data, and text mode would
    # decode it (Python 3) or translate line endings (Windows), corrupting it.
    # A new handle is opened lazily for every generated object.
    # NOTE(review): the handle is never closed explicitly here.
    image = factory.LazyAttribute(
        lambda obj: File(open(TEST_IMAGE_PATH, 'rb')))
class PinFactory(factory.Factory):
    """Factory for ``Pin`` objects with a generated submitter and image.

    The ``tags`` keyword controls tagging (create strategy only):

    * omitted / ``None`` -- one fresh tag from ``TagFactory`` is added;
    * a single ``Tag`` -- that tag is added;
    * a list, tuple, set, frozenset or ``QuerySet`` of tags -- all of them
      are added (an empty collection leaves the pin untagged).
    """
    FACTORY_FOR = Pin

    submitter = factory.SubFactory(UserFactory)
    image = factory.SubFactory(ImageFactory)

    @factory.post_generation(extract_prefix='tags')
    def add_tags(self, create, extracted, **kwargs):
        """Attach tags to the generated pin according to ``extracted``."""
        # The tag manager needs a saved pin, so under build() there is
        # nothing that can be attached.
        if not create:
            return
        if isinstance(extracted, Tag):
            self.tags.add(extracted)
        elif isinstance(extracted, (list, tuple, set, frozenset, QuerySet)):
            # Querysets go through add() too: assigning ``self.tags = qs``
            # bypasses taggit's manager (it defines no setter -- verify for
            # the installed version), so the tags would never be saved.
            self.tags.add(*extracted)
        else:
            self.tags.add(TagFactory())
class PinFactoryTest(TestCase):
    """Checks how ``PinFactory`` handles its ``tags`` keyword."""

    def test_default_tags(self):
        tags = list(PinFactory().tags.all())
        self.assertEqual(len(tags), 1)
        self.assertTrue(tags[0].name.startswith('tag_'))

    def test_custom_tag(self):
        custom = 'custom_tag'
        pin = PinFactory(tags=Tag.objects.create(name=custom))
        # get() with no filter returns the pin's single tag; looking it up by
        # pk=1 relied on the database handing out that primary key.
        self.assertEqual(pin.tags.get().name, custom)

    def test_custom_tags_list(self):
        tags = TagFactory.create_batch(2)
        pin = PinFactory(tags=tags)
        self.assertEqual(Tag.objects.count(), 2)
        self.assertEqual(set(pin.tags.all()), set(tags))

    def test_custom_tags_queryset(self):
        TagFactory.create_batch(2)
        tags = Tag.objects.all()
        PinFactory(tags=tags)
        self.assertEqual(Tag.objects.count(), 2)

    def test_empty_tags(self):
        pin = PinFactory(tags=[])
        self.assertEqual(Tag.objects.count(), 0)
        self.assertEqual(pin.tags.count(), 0)