aboutsummaryrefslogtreecommitdiffstats
path: root/libbe/command/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'libbe/command/base.py')
-rw-r--r--libbe/command/base.py221
1 files changed, 172 insertions, 49 deletions
diff --git a/libbe/command/base.py b/libbe/command/base.py
index 2318aa7..357940f 100644
--- a/libbe/command/base.py
+++ b/libbe/command/base.py
@@ -3,6 +3,7 @@
import codecs
import optparse
import os.path
+import StringIO
import sys
import libbe
@@ -37,14 +38,19 @@ def get_command(command_name):
raise UnknownCommand(command_name)
return cmd
-def get_command_class(module, command_name):
+def get_command_class(module=None, command_name=None):
"""Retrieves a command class from a module.
>>> import_xml_mod = get_command('import-xml')
>>> import_xml = get_command_class(import_xml_mod, 'import-xml')
>>> repr(import_xml)
"<class 'libbe.command.import_xml.Import_XML'>"
+ >>> import_xml = get_command_class(command_name='import-xml')
+ >>> repr(import_xml)
+ "<class 'libbe.command.import_xml.Import_XML'>"
"""
+ if module == None:
+ module = get_command(command_name)
try:
cname = command_name.capitalize().replace('-', '_')
cmd = getattr(module, cname)
@@ -167,7 +173,7 @@ class OptionFormatter (optparse.IndentedHelpFormatter):
return ''.join(ret[:-1])
class Command (object):
- """One-line command description.
+ """One-line command description here.
>>> c = Command()
>>> print c.help()
@@ -183,12 +189,8 @@ class Command (object):
name = 'command'
- def __init__(self, input_encoding=None, output_encoding=None,
- get_unconnected_storage=None, ui=None):
- self.input_encoding = input_encoding
- self.output_encoding = output_encoding
- self.get_unconnected_storage = get_unconnected_storage
- self.ui = ui # calling user-interface, e.g. for Help()
+ def __init__(self, ui=None):
+ self.ui = ui # calling user-interface
self.status = None
self.result = None
self.restrict_file_access = True
@@ -203,6 +205,21 @@ class Command (object):
self.args = []
def run(self, options=None, args=None):
+ self.status = 1 # in case we raise an exception
+ params = self._parse_options_args(options, args)
+ if params['help'] == True:
+ pass
+ else:
+ params.pop('help')
+ if params['complete'] != None:
+ pass
+ else:
+ params.pop('complete')
+
+ self.status = self._run(**params)
+ return self.status
+
+ def _parse_options_args(self, options=None, args=None):
if options == None:
options = {}
if args == None:
@@ -242,33 +259,11 @@ class Command (object):
if len(args) > len(self.args): # add some additional repeats
assert self.args[-1].repeatable == True, self.args[-1].name
params[self.args[-1].name].extend(args[len(self.args):])
-
- if params['help'] == True:
- pass
- else:
- params.pop('help')
- if params['complete'] != None:
- pass
- else:
- params.pop('complete')
-
- self._setup_io(self.input_encoding, self.output_encoding)
- self.status = self._run(**params)
- return self.status
+ return params
def _run(self, **kwargs):
raise NotImplementedError
- def _setup_io(self, input_encoding=None, output_encoding=None):
- if input_encoding == None:
- input_encoding = libbe.util.encoding.get_input_encoding()
- if output_encoding == None:
- output_encoding = libbe.util.encoding.get_output_encoding()
- self.stdin = codecs.getwriter(input_encoding)(sys.stdin)
- self.stdin.encoding = input_encoding
- self.stdout = codecs.getwriter(output_encoding)(sys.stdout)
- self.stdout.encoding = output_encoding
-
def help(self, *args):
return '\n\n'.join([self.usage(),
self._option_help(),
@@ -340,43 +335,171 @@ class Command (object):
raise UserError('file access restricted!\n %s not in %s'
% (path, repo))
- def _get_unconnected_storage(self):
- """Callback for use by commands that need it."""
+ def cleanup(self):
+ pass
+
+class InputOutput (object):
+ def __init__(self, stdin=None, stdout=None):
+ self.stdin = stdin
+ self.stdout = stdout
+
+ def setup_command(self, command):
+ if not hasattr(self.stdin, 'encoding'):
+ self.stdin.encoding = libbe.util.encoding.get_input_encoding()
+ if not hasattr(self.stdout, 'encoding'):
+ self.stdout.encoding = libbe.util.encoding.get_output_encoding()
+ command.stdin = self.stdin
+ command.stdin.encoding = self.stdin.encoding
+ command.stdout = self.stdout
+ command.stdout.encoding = self.stdout.encoding
+
+ def cleanup(self):
+ pass
+
+class StdInputOutput (InputOutput):
+ def __init__(self, input_encoding=None, output_encoding=None):
+ stdin,stdout = self._get_io(input_encoding, output_encoding)
+ InputOutput.__init__(self, stdin, stdout)
+
+ def _get_io(self, input_encoding=None, output_encoding=None):
+ if input_encoding == None:
+ input_encoding = libbe.util.encoding.get_input_encoding()
+ if output_encoding == None:
+ output_encoding = libbe.util.encoding.get_output_encoding()
+ stdin = codecs.getwriter(input_encoding)(sys.stdin)
+ stdin.encoding = input_encoding
+ stdout = codecs.getwriter(output_encoding)(sys.stdout)
+ stdout.encoding = output_encoding
+ return (stdin, stdout)
+
+class StringInputOutput (InputOutput):
+ """
+ >>> s = StringInputOutput()
+ >>> s.set_stdin('hello')
+ >>> s.stdin.read()
+ 'hello'
+ >>> s.stdin.read()
+ >>> print >> s.stdout, 'goodbye'
+ >>> s.get_stdout()
+ 'goodbye\n'
+ >>> s.get_stdout()
+ ''
+
+ Also works with unicode strings
+
+ >>> s.set_stdin(u'hello')
+ >>> s.stdin.read()
+ u'hello'
+ >>> print >> s.stdout, u'goodbye'
+ >>> s.get_stdout()
+ u'goodbye\n'
+ """
+ def __init__(self):
+ stdin = StringIO.StringIO()
+ stdin.encoding = 'utf-8'
+ stdout = StringIO.StringIO()
+ stdout.encoding = 'utf-8'
+ InputOutput.__init__(self, stdin, stdout)
+
+ def set_stdin(self, stdin_string):
+ self.stdin = StringIO.StringIO(stdin_string)
+
+ def get_stdout(self):
+ ret = self.stdout.getvalue()
+ self.stdout = StringIO.StringIO() # clear stdout for next read
+ self.stdin.encoding = 'utf-8'
+ return ret
+
+class UnconnectedStorageGetter (object):
+ def __init__(self, location):
+ self.location = location
+
+ def __call__(self):
+ return libbe.storage.get_storage(self.location)
+
+class StorageCallbacks (object):
+ def __init__(self, location=None):
+ if location == None:
+ location = '.'
+ self.location = location
+ self._get_unconnected_storage = UnconnectedStorageGetter(location)
+
+ def setup_command(self, command):
+ command._get_unconnected_storage = self.get_unconnected_storage
+ command._get_storage = self.get_storage
+ command._get_bugdir = self.get_bugdir
+
+ def get_unconnected_storage(self):
+ """
+ Callback for use by commands that need it.
+
+ The returned Storage instance is may actually be connected,
+ but commands that make use of the returned value should only
+ make use of non-connected Storage methods. This is mainly
+ intended for the init command, which calls Storage.init().
+ """
if not hasattr(self, '_unconnected_storage'):
- if self.get_unconnected_storage == None:
+ if self._get_unconnected_storage == None:
raise NotImplementedError
- self._unconnected_storage = self.get_unconnected_storage()
+ self._unconnected_storage = self._get_unconnected_storage()
return self._unconnected_storage
- def _get_storage(self):
- """
- Callback for use by commands that need it.
+ def set_unconnected_storage(self, unconnected_storage):
+ self._unconnected_storage = unconnected_storage
- Note that with the current implementation,
- _get_unconnected_storage() will not work after this method
- runs, but that shouldn't be an issue for any command I can
- think of...
- """
+ def get_storage(self):
+ """Callback for use by commands that need it."""
if not hasattr(self, '_storage'):
- self._storage = self._get_unconnected_storage()
+ self._storage = self.get_unconnected_storage()
self._storage.connect()
version = self._storage.storage_version()
if version != libbe.storage.STORAGE_VERSION:
raise libbe.storage.InvalidStorageVersion(version)
return self._storage
- def _get_bugdir(self):
+ def set_storage(self, storage):
+ self._storage = storage
+
+ def get_bugdir(self):
"""Callback for use by commands that need it."""
if not hasattr(self, '_bugdir'):
- self._bugdir = libbe.bugdir.BugDir(self._get_storage(), from_storage=True)
+ self._bugdir = libbe.bugdir.BugDir(self.get_storage(),
+ from_storage=True)
return self._bugdir
+ def set_bugdir(self, bugdir):
+ self._bugdir = bugdir
+
+ def cleanup(self):
+ if hasattr(self, '_storage'):
+ self._storage.disconnect()
+
+class UserInterface (object):
+ def __init__(self, io=None, location=None):
+ if io == None:
+ io = StringInputOutput()
+ self.io = io
+ self.storage_callbacks = StorageCallbacks(location)
+ self.restrict_file_access = True
+
+ def help(self):
+ raise NotImplementedError
+
+ def run(self, command, options=None, args=None):
+ command.ui = self
+ self.io.setup_command(command)
+ self.storage_callbacks.setup_command(command)
+ command.restrict_file_access = self.restrict_file_access
+ command._get_user_id = self._get_user_id
+ return command.run(options, args)
+
def _get_user_id(self):
"""Callback for use by commands that need it."""
if not hasattr(self, '_user_id'):
- self._user_id = libbe.ui.util.user.get_user_id(self._get_storage())
+ self._user_id = libbe.ui.util.user.get_user_id(
+ self.storage_callbacks.get_storage())
return self._user_id
def cleanup(self):
- if hasattr(self, '_storage'):
- self._storage.disconnect()
+ self.storage_callbacks.cleanup()
+ self.io.cleanup()