diff options
Diffstat (limited to 'libbe/storage')
-rw-r--r-- | libbe/storage/base.py | 21 | ||||
-rw-r--r-- | libbe/storage/vcs/base.py | 2 |
2 files changed, 18 insertions, 5 deletions
diff --git a/libbe/storage/base.py b/libbe/storage/base.py index ffde475..1c711fa 100644 --- a/libbe/storage/base.py +++ b/libbe/storage/base.py @@ -139,6 +139,7 @@ class Storage (object): self._writeable = True # hard limit (backend choice) self.versioned = False self.can_init = True + self.connected = False def __str__(self): return '<%s %s %s>' % (self.__class__.__name__, id(self), self.repo) @@ -190,6 +191,7 @@ class Storage (object): if self.is_readable() == False: raise NotReadable('Cannot connect to unreadable storage.') self._connect() + self.connected = True def _connect(self): try: @@ -204,6 +206,12 @@ class Storage (object): """Close the connection to the repository.""" if self.is_writeable() == False: return + if self.connected == False: + return + self._disconnect() + self.connected = False + + def _disconnect(self): f = open(os.path.join(self.repo, 'repo.pkl'), 'wb') pickle.dump(dict((k,v._objects_to_ids()) for k,v in self._data.items()), f, -1) @@ -342,10 +350,7 @@ class VersionedStorage (Storage): for t in d] f.close() - def disconnect(self): - """Close the connection to the repository.""" - if self.is_writeable() == False: - return + def _disconnect(self): f = open(os.path.join(self.repo, 'repo.pkl'), 'wb') pickle.dump([dict((k,v._objects_to_ids()) for k,v in t.items()) for t in self._data], f, -1) @@ -478,6 +483,14 @@ if TESTING == True: """Should connect after initialization.""" self.s.connect() + class Storage_connect_disconnect_TestCase (StorageTestCase): + """Test cases for Storage.connect and .disconnect methods.""" + + def test_multiple_disconnects(self): + """Should be able to call .disconnect multiple times.""" + self.s.disconnect() + self.s.disconnect() + class Storage_add_remove_TestCase (StorageTestCase): """Test cases for Storage.add, .remove, and .recursive_remove methods.""" diff --git a/libbe/storage/vcs/base.py b/libbe/storage/vcs/base.py index b47ed2f..99f43f3 100644 --- a/libbe/storage/vcs/base.py +++ b/libbe/storage/vcs/base.py @@ -683,7 +683,7 @@ os.listdir(self.get_path("bugs")): self._cached_path_id.connect() self.check_storage_version() - def disconnect(self): + def _disconnect(self): self._cached_path_id.disconnect() def _add_path(self, path, directory=False): |