from django.db.models import ManyToManyField, Count, signals
from django.db.models.fields.related import add_lazy_relation

class RelationCardinalityException(Exception):
    '''Raised when adding relationships would exceed a maximum cardinality.'''


class MaxCardinalityManyToManyField(ManyToManyField):
    '''A ManyToManyField that constrains the maximum number of relationships in
    one or both directions.

    An upper bound can be set for the forward relationships (``max_cardinality``)
    and/or the reverse relationships (``reverse_max_cardinality``). If either is
    left undefined (or None), it defaults to unbounded. Attempting to add one or
    more relationships that would result in exceeding the bound(s) raises a
    :class:`RelationCardinalityException`.

    For symmetric relationships, ``max_cardinality`` and ``reverse_max_cardinality``
    must be equal. As a shortcut, leaving one of the two undefined defaults to
    the other, so just defining one of them is enough.

    Example::

        class Topping(models.Model):
            name = models.CharField(max_length=128, unique=True)

        class Pizza(models.Model):
            name = models.CharField(max_length=128, unique=True)
            toppings = MaxCardinalityManyToManyField(Topping,
                                                     max_cardinality=2,
                                                     reverse_max_cardinality=3)

        >>> mushrooms = Topping.objects.get_or_create(name='mushrooms')[0]
        >>> anchovies = Topping.objects.get_or_create(name='anchovies')[0]
        >>> mozzarella = Topping.objects.get_or_create(name='mozzarella')[0]

        >>> margherita = Pizza.objects.get_or_create(name='margherita')[0]
        >>> marinara = Pizza.objects.get_or_create(name='marinara')[0]
        >>> sicilian = Pizza.objects.get_or_create(name='sicilian')[0]
        >>> california = Pizza.objects.get_or_create(name='california')[0]

        >>> # try to exceed max_cardinality through 'toppings'
        >>> margherita.toppings.add(mushrooms, anchovies)
        >>> margherita.toppings.add(mozzarella)
        Traceback (most recent call last):
        ...
        RelationCardinalityException: No more pizza-topping relationships allowed for pizza.pk=1
        >>> margherita.toppings.clear()

        >>> # try to exceed max_cardinality through 'pizza_set'
        >>> for topping in mushrooms, mozzarella:
        ...     topping.pizza_set = [marinara, sicilian]
        >>> anchovies.pizza_set.add(sicilian)
        Traceback (most recent call last):
        ...
        RelationCardinalityException: No more pizza-topping relationships allowed for pizza.pk=3
        >>> for topping in mushrooms, mozzarella, anchovies:
        ...     topping.pizza_set.clear()

        >>> # try to exceed reverse_max_cardinality through 'pizza_set'.
        >>> mushrooms.pizza_set.add(margherita, marinara, sicilian)
        >>> mushrooms.pizza_set.add(california)
        Traceback (most recent call last):
        ...
        RelationCardinalityException: No more pizza-topping relationships allowed for topping.pk=1
        >>> mushrooms.pizza_set.clear()

        >>> # try to exceed reverse_max_cardinality through 'toppings'
        >>> for pizza in margherita, marinara, sicilian:
        ...     pizza.toppings = [mushrooms, mozzarella]
        >>> california.toppings.add(mushrooms)
        Traceback (most recent call last):
        ...
        RelationCardinalityException: No more pizza-topping relationships allowed for topping.pk=1
    '''

    def __init__(self, to, **kwargs):
        '''Create the field.

        Accepts every ``ManyToManyField`` argument plus two optional keyword
        arguments, ``max_cardinality`` and ``reverse_max_cardinality``
        (both default to None, i.e. unbounded).

        Raises ``ValueError`` if the relationship is symmetrical and the two
        bounds are both given but differ.
        '''
        # Pop the custom options first: the parent constructor is only handed
        # the remaining keyword arguments.
        self.max_cardinality = kwargs.pop('max_cardinality', None)
        self.reverse_max_cardinality = kwargs.pop('reverse_max_cardinality', None)
        super(MaxCardinalityManyToManyField, self).__init__(to, **kwargs)
        if self.rel.symmetrical:
            # A symmetrical relationship has a single bound; let either
            # option stand in for the other and reject conflicting values.
            if self.reverse_max_cardinality is None:
                self.reverse_max_cardinality = self.max_cardinality
            elif self.max_cardinality is None:
                self.max_cardinality = self.reverse_max_cardinality
            elif self.max_cardinality != self.reverse_max_cardinality:
                raise ValueError('Symmetrical relationships must have equal '
                                 'forward and reverse max cardinality')

    def contribute_to_class(self, cls, name):
        '''Attach the field to ``cls`` and, if any bound is set, hook the
        validation signal handlers onto the intermediary (through) model.
        '''
        super(MaxCardinalityManyToManyField, self).contribute_to_class(cls, name)
        # Compare against None rather than relying on truthiness: a bound of
        # 0 is a real (if extreme) constraint and must not be mistaken for
        # "unbounded".
        if self.max_cardinality is None and self.reverse_max_cardinality is None:
            return
        through = self.rel.through
        if not through:
            # No intermediary model to attach the handlers to.
            return
        if isinstance(through, basestring):
            # The through model is referenced by name; defer the hookup
            # until add_lazy_relation hands us the resolved model.
            add_lazy_relation(cls, self, through,
                lambda self, through, cls: self.__connect_through_signals(through))
        else:
            self.__connect_through_signals(through)

    def __connect_through_signals(self, through):
        '''Connect the ``pre_save`` and ``m2m_changed`` validators for the
        ``through`` model.

        Both receivers are closures over this field and are connected with
        ``weak=False`` so that the signal keeps a strong reference to them.
        '''

        def validate_cardinalities(sender, instance, **kwargs):
            # Only rows that are not stored yet add a relationship; saving
            # an existing through row is not re-validated.
            pk = instance._get_pk_val()
            # XXX: _base_manager or _default_manager ?
            exists = pk is not None and instance.__class__._base_manager.filter(pk=pk).exists()
            if not exists:
                # NOTE(review): the column names are used as attribute names
                # here; presumably they coincide with the FK attnames unless
                # the through model sets a custom db_column -- confirm.
                self._validate_cardinality(getattr(instance, self.m2m_column_name()),
                                           reverse=False)
                self._validate_cardinality(getattr(instance, self.m2m_reverse_name()),
                                           reverse=True)
        signals.pre_save.connect(validate_cardinalities, sender=through, weak=False)

        def m2m_validate_cardinalities(sender, instance, action, reverse, pk_set, **kwargs):
            # Only additions can push a count over its bound.
            if action != 'pre_add' or not pk_set:
                return
            # ``instance`` gains len(pk_set) relationships at once, while
            # each object in ``pk_set`` gains exactly one (the default
            # ``num_added``).
            if reverse:
                self._validate_cardinality(*pk_set, reverse=False)
                self._validate_cardinality(instance._get_pk_val(),
                                           reverse=True, num_added=len(pk_set))
            else:
                self._validate_cardinality(instance._get_pk_val(),
                                           reverse=False, num_added=len(pk_set))
                self._validate_cardinality(*pk_set, reverse=True)
        signals.m2m_changed.connect(m2m_validate_cardinalities, sender=through, weak=False)

    def _validate_cardinality(self, *pks, **kwargs):
        '''Check that each object in ``pks`` can take on more relationships.

        Keyword arguments:

        ``reverse`` (required)
            False to check ``pks`` against ``max_cardinality``, True to check
            them against ``reverse_max_cardinality``. Symmetrical
            relationships always use the forward side.
        ``num_added`` (default 1)
            How many relationships are about to be added to each pk.

        Raises :class:`RelationCardinalityException` for the first pk whose
        current count plus ``num_added`` would exceed the bound.
        '''
        if not kwargs['reverse'] or self.rel.symmetrical:
            field_name = self.m2m_field_name()
            threshold = self.max_cardinality
        else:
            field_name = self.m2m_reverse_field_name()
            threshold = self.reverse_max_cardinality
        if threshold is None:
            return  # unbounded in this direction
        # Highest current count that still leaves room for the additions.
        threshold -= kwargs.get('num_added', 1)
        for pk, count in self._get_counts(field_name, pks).iteritems():
            if count > threshold:
                raise RelationCardinalityException('No more %s allowed for %s.pk=%s' % (
                    unicode(self.rel.through._meta.verbose_name_plural), field_name, pk))

    def _get_counts(self, field_name, pks):
        '''Return a dict mapping each pk in ``pks`` to the number of through
        rows whose ``field_name`` refers to it (0 if there are none).
        '''
        pk2count = dict.fromkeys(pks, 0)  # ensure that all pks have a count
        # The trailing order_by() clears any default ordering declared on the
        # through model. Otherwise the ordering columns would join the
        # GROUP BY and split each pk's count over several rows, of which
        # dict.update() would keep only the last.
        pk2count.update(
            self.rel.through._default_manager.values_list(field_name).
            filter(**{field_name + '__in': list(pk2count)}).
            annotate(Count(field_name)).order_by())
        return pk2count
