"""
Tools for persisting object graphs with Django ORM.

These classes make some assumptions about your datamodel:

* Object graph does not have cyclic SQL insertion order dependencies.
* Objects with pk=None are being inserted, all other pk values
  indicate an update.
"""
import threading

from django.db.models.base import Model
from django.db.models.query import QuerySet
from django.db.models.fields.related import ForeignKey


def is_managed(item):
    """Return True if item is managed by Django ORM (is a Model)."""
    return isinstance(item, Model)


def is_persistent(item):
    """Return True if an item has already been saved (pk is set)."""
    return item.pk is not None


class GraphError(Exception):
    """Raised for every error detected by the graph tools."""
    pass


class Collection(object):
    """
    Allows iterables to be persisted in an object graph in an easier way.

    While the parent object is unsaved the items are kept in an
    in-memory list; once the parent is persistent the items are read
    through the query returned by get_item_query().
    """

    @classmethod
    def set_property(cls, model, attr, parent_attr, set_attr=None):
        """
        Install a lazily created Collection as a property on a model.

        model       -- model class that receives the property.
        attr        -- name of the property to create on model.
        parent_attr -- attribute on each item that points at the parent.
        set_attr    -- attribute on the parent whose .all() yields the
                       items (see get_item_query).

        The (attr, private_attr) pair is also recorded in the model's
        _collection_attrs list, which GraphSaver reads.
        """
        private_attr = '_' + attr

        def _get(self):
            # Build the collection on first access and cache it on
            # the instance under the private name.
            val = getattr(self, private_attr, None)
            if val is None:
                val = cls(self, parent_attr, set_attr=set_attr)
                setattr(self, private_attr, val)
            return val

        def _set(self, val):
            setattr(self, private_attr, val)

        setattr(model, attr, property(_get, _set))

        collection_attrs = getattr(model, '_collection_attrs', None)
        if collection_attrs is None:
            collection_attrs = []
            setattr(model, '_collection_attrs', collection_attrs)
        collection_attrs.append((attr, private_attr))

    def __init__(self, parent, parent_attr, set_attr=None, list_cls=None):
        """
        parent      -- object that owns the collection.
        parent_attr -- attribute on each item that refers to parent.
        set_attr    -- attribute on parent used by get_item_query().
        list_cls    -- factory for the in-memory item list
                       (defaults to list).
        """
        self.parent = parent
        self.parent_attr = parent_attr
        self.set_attr = set_attr
        self.list_cls = list if list_cls is None else list_cls
        # Holds list of items if parent is not persisted
        self._item_list = None
        # Holds set of items that have been accessed
        self._accessed_items = set()

    def __iter__(self):
        """Iterate over all items, recording each one as accessed."""
        for item in self._items:
            self._accessed_items.add(item)
            yield item

    def __getitem__(self, k):
        """
        Return the item at index k, or a generator for a slice.

        Every item handed out is recorded as accessed.
        """
        if isinstance(k, slice):
            def _gen():
                for item in self._items.__getitem__(k):
                    self._accessed_items.add(item)
                    yield item
            return _gen()

        item = self._items.__getitem__(k)
        self._accessed_items.add(item)
        return item

    def __len__(self):
        """Return the number of items (uses count() for a QuerySet)."""
        items = self._items
        if isinstance(items, QuerySet):
            return items.count()
        return len(items)

    def _get_accessed_items(self):
        """
        Return the items GraphSaver should consider for saving.

        For a persistent parent: the accessed items that still point
        at this parent. For an unsaved parent: every item in the list.
        """
        if not is_persistent(self.parent):
            return self._items

        # Return all items that have been accessed,
        # but only if they still belong in the collection.
        accessed_items = set()
        for item in self._accessed_items:
            parent = getattr(item, self.parent_attr, None)
            if parent is None or parent.pk != self.parent.pk:
                continue
            accessed_items.add(item)
        return accessed_items

    accessed_items = property(_get_accessed_items)

    def _get_items(self):
        """Returns all items in the collection."""
        if is_persistent(self.parent):
            return self.get_item_query()
        if self._item_list is None:
            # Created lazily so an untouched collection stays cheap.
            self._item_list = self.list_cls()
        return self._item_list

    _items = property(_get_items)

    def add(self, item):
        """
        Add an item to the collection.

        The item's parent attribute is pointed at the parent. With a
        persistent parent the item is saved immediately; otherwise it
        is appended to the in-memory list unless that very object is
        already there.

        Raises GraphError if item is not a Django model instance.
        """
        if not is_managed(item):
            # Bad things will happen
            raise GraphError('Cannot add unmanaged object.')

        setattr(item, self.parent_attr, self.parent)
        if is_persistent(self.parent):
            item.save()
            return

        items = self._items
        # Identity, not equality: distinct objects must both be kept.
        if not any(existing_item is item for existing_item in items):
            items.append(item)

    def remove(self, item, delete=False):
        """
        Remove an item from the collection.

        Unless delete is True the item's parent attribute is cleared.
        With a persistent parent the item is then deleted (delete=True)
        or saved; otherwise it is dropped from the in-memory list.
        """
        if delete is False:
            setattr(item, self.parent_attr, None)

        if is_persistent(self.parent):
            if delete is True:
                item.delete()
            else:
                item.save()
            return

        # Match by identity, exactly as add() does. Django model
        # equality is based on the primary key, so an equality based
        # list.remove() could drop a different object than the one
        # that was passed in.
        items = self._items
        for index, existing_item in enumerate(items):
            if existing_item is item:
                del items[index]
                break

    def clear(self, delete=False):
        """Remove all items from the collection."""
        # Snapshot first: remove() mutates the underlying list.
        for item in list(self._items):
            self.remove(item, delete=delete)

    def update(self, new):
        """Add every item of the iterable new to the collection."""
        for item in new:
            self.add(item)

    def get_item_query(self):
        """Override this method to use a custom QuerySet."""
        return getattr(self.parent, self.set_attr).all()


class DependencyList(object):
    """Lists all parents for a child where parent == dependency."""

    __slots__ = ['obj', 'deps']

    def __init__(self, obj=None, deps=None):
        # obj is the child; deps holds one Dependency per referrer.
        self.obj = obj
        self.deps = deps


class Dependency(object):
    """Signifies a dependency between two objects in the graph."""

    __slots__ = ['parent', 'field', 'level']

    def __init__(self, parent=None, field=None, level=None):
        # parent refers to the child through field; level is the
        # recursion depth at which the reference was found.
        self.parent = parent
        self.field = field
        self.level = level


class GraphSaver(object):
    """Provides methods to simplify the persistence of large object graphs."""

    def _add_dep(self, parent, child, field, deps, level):
        """
        Adds a dependency to the list.

        Records that parent refers to child through field.

        Raises GraphError if child has no dependency list yet, or if
        parent is already registered as depending on child.
        """
        dep_list = self._get_dep_list(child, deps)
        if dep_list is None:
            raise GraphError('Dependency list not found!')
        # NOTE(review): this also triggers when one parent has two
        # foreign keys to the same child -- confirm that is intended.
        for dep in dep_list.deps:
            if dep.parent is parent:
                raise GraphError('Circular dependency detected.')
        dep_list.deps.append(
            Dependency(parent=parent, field=field, level=level))

    def _get_dep_list(self, obj, deps):
        """Returns a list that dependencies can be added to, or None"""
        # Keyed by id() so that objects are tracked by identity.
        return deps.get(id(obj), None)

    def _init_dep_list(self, obj, deps):
        """Returns a list that dependencies can be added to."""
        dep_list = self._get_dep_list(obj, deps)
        if dep_list is None:
            dep_list = DependencyList(obj=obj, deps=[])
            deps[id(obj)] = dep_list
        return dep_list

    def _build_deps(self, parent, deps, update=True, level=1):
        """
        Recursively register parent and everything it refers to.

        parent -- model instance to inspect.
        deps   -- dict of id(obj) -> DependencyList, filled in place.
        update -- when True, already persistent objects are registered
                  for saving as well.
        level  -- current recursion depth (1 for the root objects).
        """
        if self._get_dep_list(parent, deps) is not None:
            # This object has already had its dependencies added!
            return

        if (is_persistent(parent) is False) or (update is True):
            # Makes sure the parent obj shows
            # up in the dependency list, so that
            # it gets saved!
            self._init_dep_list(parent, deps)

        # NOTE(review): get_all_field_names/get_field_by_name look like
        # the legacy Django _meta API -- confirm the targeted version.
        for name in parent._meta.get_all_field_names():
            field_info = parent._meta.get_field_by_name(name)
            if field_info[2] is False:
                # This is a magic reverse reference to
                # some other related field. Ignore it.
                continue
            field = field_info[0]
            if not isinstance(field, ForeignKey):
                continue
            child = getattr(parent, name)
            if child is None:
                continue
            # This is a dependency!
            # Recursively add to deps.
            # NOTE(review): with update=False a persistent child gets
            # no dependency list, so _add_dep raises -- confirm.
            self._build_deps(child, deps, update=update, level=level + 1)
            self._add_dep(parent, child, field, deps, level)

        # Save items in any collections
        self._build_collection_deps(parent, deps, update=update, level=level)

    def _build_collection_deps(self, parent, deps, update=True, level=1):
        """Add dependencies from collection objects."""
        collection_attrs = getattr(parent.__class__, '_collection_attrs', [])
        for collection_attr in collection_attrs:
            collection = getattr(parent, collection_attr[0], None)
            if collection is None:
                continue
            for item in collection.accessed_items:
                self._build_deps(item, deps, update=update, level=level)

    def _save_deps(self, deps):
        """Group the dependency lists by level and save them."""
        # Group dependencies by their level. An object's level is the
        # highest level among its referrers, 0 if nothing refers to it.
        dep_levels = {}
        for dep_list in deps.values():
            max_level = max([0] + [dep.level for dep in dep_list.deps])
            dep_levels.setdefault(max_level, []).append(dep_list)
        self._save_by_level(dep_levels)

    def _save_by_level(self, levels):
        """Save every dependency list, highest level first."""
        # Save children by their dependency level
        for level in sorted(levels, reverse=True):
            for dep_list in levels[level]:
                self._save_dep_list(dep_list)

    def _save_dep_list(self, dep_list):
        """Save one object, then re-point its referrers at it."""
        # All children in the list should be the same,
        # so we only need to save the 1st one.
        child = dep_list.obj
        child.save()
        # Re-assign the relation on every referrer after the save.
        for dep in dep_list.deps:
            setattr(dep.parent, dep.field.name, child)

    def save_many(self, items, update=True):
        """
        Save multiple object graphs.

        Raises GraphError if any item is not a Django model instance.
        """
        deps = {}
        for item in items:
            if not is_managed(item):
                raise GraphError('Cannot save unmanaged item.')
            self._build_deps(item, deps, level=1, update=update)
        self._save_deps(deps)

    def save(self, item, update=True):
        """Save an object graph."""
        return self.save_many((item,), update=update)


class Session(object):
    """
    Modifies QuerySets to return consistent object graphs.

    This will modify query behavior for all code executed
    during a 'with' statement.
    """

    # Shared by all Session instances; state is kept per thread.
    _session_objs = threading.local()

    def __init__(self):
        self._entered = False
        self._teardown_on_exit = False

    def __enter__(self, *args, **kwargs):
        """
        Setup session system.

        Returns the session so that 'with Session() as s' works.
        Raises GraphError if this instance is already entered.
        """
        if self._entered:
            raise GraphError('Cannot re-enter session!')
        if not self._is_setup():
            self._setup()
        self._entered = True
        return self

    def __exit__(self, *args, **kwargs):
        """Leave the session; tear down only what this instance set up."""
        # Reset the flags so the instance can be used again later.
        self._entered = False
        if self._teardown_on_exit:
            self._teardown_on_exit = False
            self._teardown()

    def _is_setup(self):
        """Return True if a session is active in the current thread."""
        return getattr(self._session_objs, '_setup', False)

    def _setup(self):
        """Activate the session for this thread and install patches."""
        self._clear()
        self._patch_query_set()
        self._patch_model_base()
        self._session_objs._setup = True
        self._teardown_on_exit = True

    def _teardown(self):
        """Deactivate the session; the monkey patches stay installed."""
        self._clear()
        self._session_objs._setup = False

    def _clear(self):
        """Clear all existing session objects."""
        self._session_objs.objs = {}

    def _patch_query_set(self):
        """Monkey patches QuerySet to return cached objects."""
        if getattr(QuerySet, '_session_enabled', False):
            # Already patched; never wrap twice.
            return

        original_iterator = QuerySet.iterator

        def _gen(q, *args, **kwargs):
            # Swap every row for the instance the session already knows.
            for item in original_iterator(q, *args, **kwargs):
                yield self.add(item)

        def _s_iterator(q, *args, **kwargs):
            if self._is_setup():
                return _gen(q, *args, **kwargs)
            return original_iterator(q, *args, **kwargs)

        QuerySet.iterator = _s_iterator
        QuerySet._session_enabled = True

    def _patch_model_base(self):
        """Monkey patches Model to add saved items to session."""
        if getattr(Model, '_session_enabled', False):
            # Already patched; never wrap twice.
            return

        original_save = Model.save

        def _s_save(item, *args, **kwargs):
            if not self._is_setup():
                return original_save(item, *args, **kwargs)

            # Only choose a force flag when the caller expressed no
            # preference; positional arguments may already carry one.
            flags_unset = (
                not args
                and 'force_insert' not in kwargs
                and 'force_update' not in kwargs
            )

            if is_persistent(item):
                # updating this object
                if flags_unset:
                    kwargs['force_update'] = True
                self.add_with_dup_check(item)
                return original_save(item, *args, **kwargs)

            # inserting this object
            if flags_unset:
                kwargs['force_insert'] = True
            result = original_save(item, *args, **kwargs)
            # The pk only exists after the insert, so register last.
            self.add_with_dup_check(item)
            return result

        Model.save = _s_save
        Model._session_enabled = True

    def _obj_key(self, cls, pk):
        """Returns session key for an object."""
        # Read the class's own __dict__: a plain getattr would let a
        # subclass inherit the key cached on its base class, and both
        # classes would then share session entries.
        cls_key = cls.__dict__.get('_cls_key', None)
        if cls_key is None:
            cls_key = cls.__module__ + '.' + cls.__name__
            cls._cls_key = cls_key
        return '%s:%s' % (cls_key, pk)

    def add(self, item):
        """
        Adds an object to the session. Added object is returned.

        If item's key matches existing key, existing object is returned.
        Raises GraphError if item has not been saved yet.
        """
        if not is_persistent(item):
            raise GraphError('Cannot add unpersisted object.')
        key = self._obj_key(item.__class__, item.pk)
        objs = self._session_objs.objs
        existing = objs.get(key, None)
        if existing is None:
            objs[key] = item
            return item
        return existing

    def add_with_dup_check(self, item):
        """
        Add a persistent object to the session.

        Raises GraphError if a different instance with the same key
        is already registered.
        """
        if self.add(item) is not item:
            raise GraphError(
                'Instance with identical id already exists in session.')