diff options
-rw-r--r-- | libbe/util/wsgi.py | 176 |
1 files changed, 151 insertions, 25 deletions
diff --git a/libbe/util/wsgi.py b/libbe/util/wsgi.py index fd219e1..eddf36f 100644 --- a/libbe/util/wsgi.py +++ b/libbe/util/wsgi.py @@ -27,8 +27,12 @@ See Also import copy import hashlib import logging +import logging.handlers +import os import os.path import re +import select +import signal import StringIO import sys import time @@ -58,6 +62,7 @@ except ImportError: import libbe.util.encoding import libbe.command import libbe.command.base +import libbe.command.util import libbe.storage @@ -573,6 +578,12 @@ class ServerCommand (libbe.command.base.Command): Use this as a base class to build commands that serve a web interface. """ + _daemon_actions = ['start', 'stop'] + _daemon_action_present_participle = { + 'start': 'starting', + 'stop': 'stopping', + } + def __init__(self, *args, **kwargs): super(ServerCommand, self).__init__(*args, **kwargs) self.options.extend([ @@ -584,6 +595,23 @@ class ServerCommand (libbe.command.base.Command): help='Set host string (blank for localhost)', arg=libbe.command.Argument( name='host', metavar='HOST', default='localhost')), + libbe.command.Option(name='daemon', + help=('Start or stop a server daemon. Stopping requires ' + 'a PID file'), + arg=libbe.command.Argument( + name='daemon', metavar='ACTION', + completion_callback=libbe.command.util.Completer( + self._daemon_actions))), + libbe.command.Option(name='pidfile', short_name='p', + help='Store the process id in the given path', + arg=libbe.command.Argument( + name='pidfile', metavar='FILE', + completion_callback=libbe.command.util.complete_path)), + libbe.command.Option(name='logfile', + help='Log to the given path (instead of stdout)', + arg=libbe.command.Argument( + name='logfile', metavar='FILE', + completion_callback=libbe.command.util.complete_path)), libbe.command.Option(name='read-only', short_name='r', help='Dissable operations that require writing'), libbe.command.Option(name='notify', short_name='n', @@ -604,7 +632,14 @@ class ServerCommand (libbe.command.base.Command): ]) def _run(self, **params): - self._setup_logging() + if params['daemon'] not in self._daemon_actions + [None]: + raise libbe.command.UserError( + 'Invalid daemon action "{}".\nValid actions:\n {}'.format( + params['daemon'], self._daemon_actions)) + self._setup_logging(params) + if params['daemon'] not in [None, 'start']: + self._manage_daemon(params) + return storage = self._get_storage() if params['read-only']: writeable = storage.writeable @@ -632,15 +667,22 @@ class ServerCommand (libbe.command.base.Command): def _get_app(self, logger, storage, **kwargs): raise NotImplementedError() - def _setup_logging(self, log_level=logging.INFO): - self.logger = logging.getLogger('be-{}'.format(self.name)) - self.log_level = logging.INFO - console = logging.StreamHandler(self.stdout) - console.setFormatter(logging.Formatter('%(message)s')) - self.logger.addHandler(console) + def _setup_logging(self, params, log_level=logging.INFO): + self.logger = logging.getLogger('be.{}'.format(self.name)) + self.log_level = log_level + if params['logfile']: + path = os.path.abspath(os.path.expanduser( + params['logfile'])) + handler = logging.handlers.TimedRotatingFileHandler( + path, when='w6', interval=1, backupCount=4, + encoding=libbe.util.encoding.get_text_file_encoding()) + else: + handler = logging.StreamHandler(self.stdout) + handler.setFormatter(logging.Formatter('%(message)s')) + self.logger.addHandler(handler) self.logger.propagate = False if log_level is not None: - console.setLevel(log_level) + handler.setLevel(log_level) self.logger.setLevel(log_level) def _get_server(self, params, app): @@ -648,12 +690,15 @@ class ServerCommand (libbe.command.base.Command): 'socket-name':params['host'], 'port':params['port'], } + if params['ssl']: + details['protocol'] = 'HTTPS' + else: + details['protocol'] = 'HTTP' app = BEExceptionApp(app, logger=self.logger) app = HandlerErrorApp(app, logger=self.logger) app = ExceptionApp(app, logger=self.logger) - if params['ssl'] == True: - details['protocol'] = 'HTTPS' - if cherrypy == None: + if params['ssl']: + if cherrypy is None: raise libbe.command.UserError( '--ssl requires the cherrypy module') server = cherrypy.wsgiserver.CherryPyWSGIServer( @@ -661,8 +706,8 @@ class ServerCommand (libbe.command.base.Command): #server.throw_errors = True #server.show_tracebacks = True private_key,certificate = _get_cert_filenames( - 'be-server', logger=self.logger) - if cherrypy.wsgiserver.ssl_builtin == None: + 'be-server', logger=self.logger, level=self.log_level) + if cherrypy.wsgiserver.ssl_builtin is None: server.ssl_module = 'builtin' server.ssl_private_key = private_key server.ssl_certificate = certificate @@ -671,27 +716,106 @@ class ServerCommand (libbe.command.base.Command): cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter( certificate=certificate, private_key=private_key)) else: - details['protocol'] = 'HTTP' server = wsgiref.simple_server.make_server( params['host'], params['port'], app, handler_class=SilentRequestHandler) return (server, details) + def _daemonize(self, params): + signal.signal(signal.SIGTERM, self._sigterm) + self.logger.log(self.log_level, 'Daemonizing') + pid = os.fork() + if pid > 0: + os._exit(0) + os.setsid() + pid = os.fork() + if pid > 0: + os._exit(0) + self.logger.log( + self.log_level, 'Daemonized with PID {}'.format(os.getpid())) + + def _get_pidfile(self, params): + params['pidfile'] = os.path.abspath(os.path.expanduser( + params['pidfile'])) + self.logger.log( + self.log_level, 'Get PID file at {}'.format(params['pidfile'])) + if os.path.exists(params['pidfile']): + raise libbe.command.UserError( + 'PID file {} already exists'.format(params['pidfile'])) + pid = os.getpid() + with open(params['pidfile'], 'w') as f: # race between exist and open + f.write(str(os.getpid())) + self.logger.log( + self.log_level, 'Got PID file as {}'.format(pid)) + def _start_server(self, params, server, details): - self.logger.log(self.log_level, + if params['daemon']: + self._daemonize(params=params) + if params['pidfile']: + self._get_pidfile(params) + self.logger.log( + self.log_level, ('Serving {protocol} on {socket-name} port {port} ...\n' 'BE repository {repo}').format(**details)) - if params['ssl']: + params['server stopped'] = False + if isinstance(server, wsgiref.simple_server.WSGIServer): + try: + server.serve_forever() + except select.error as e: + if len(e.args) == 2 and e.args[1] == 'Interrupted system call': + pass + else: + raise + else: # CherryPy server server.start() - else: - server.serve_forever() def _stop_server(self, params, server): + if params['server stopped']: + return # already stopped, e.g. via _sigterm() + params['server stopped'] = True self.logger.log(self.log_level, 'Closing server') - if params['ssl'] == True: + if isinstance(server, wsgiref.simple_server.WSGIServer): + server.server_close() + else: server.stop() + if params['pidfile']: + os.remove(params['pidfile']) + + def _sigterm(self, signum, frame): + self.logger.log(self.log_level, 'Handling SIGTERM') + # extract params and server from the stack + f = frame + while f is not None and f.f_code.co_name != '_start_server': + f = f.f_back + if f is None: + self.logger.log( + self.log_level, + 'SIGTERM from outside _start_server(): {}'.format( + frame.f_code)) + return # where did this signal come from? + params = f.f_locals['params'] + server = f.f_locals['server'] + self._stop_server(params=params, server=server) + + def _manage_daemon(self, params): + "Daemon management (any action besides 'start')" + if not params['pidfile']: + raise libbe.command.UserError( + 'daemon management requires --pidfile') + try: + with open(params['pidfile'], 'r') as f: + pid = f.read().strip() + except IOError as e: + raise libbe.command.UserError( + 'could not find PID file: {}'.format(e)) + pid = int(pid) + pp = self._daemon_action_present_participle[params['daemon']].title() + self.logger.log( + self.log_level, '{} daemon running on process {}'.format(pp, pid)) + if params['daemon'] == 'stop': + os.kill(pid, signal.SIGTERM) else: - server.server_close() + raise NotImplementedError(params['daemon']) def _long_help(self): raise NotImplementedError() @@ -908,7 +1032,8 @@ if libbe.TESTING: # The following certificate-creation code is adapted from pyOpenSSL's # examples. -def _get_cert_filenames(server_name, autogenerate=True, logger=None): +def _get_cert_filenames(server_name, autogenerate=True, logger=None, + level=None): """ Generate private key and certification filenames. get_cert_filenames(server_name) -> (pkey_filename, cert_filename) @@ -918,7 +1043,7 @@ def _get_cert_filenames(server_name, autogenerate=True, logger=None): if autogenerate: for file in [pkey_file, cert_file]: if not os.path.exists(file): - _make_certs(server_name, logger) + _make_certs(server_name, logger=logger, level=level) return (pkey_file, cert_file) def _create_key_pair(type, bits): @@ -1007,7 +1132,7 @@ def _create_certificate(req, (issuerCert, issuerKey), serial, cert.sign(issuerKey, digest) return cert -def _make_certs(server_name, logger=None) : +def _make_certs(server_name, logger=None, level=None): """Generate private key and certification files. `mk_certs(server_name) -> (pkey_filename, cert_filename)` @@ -1018,8 +1143,9 @@ def _make_certs(server_name, logger=None) : pkey_file,cert_file = _get_cert_filenames( server_name, autogenerate=False) if logger != None: - logger.log(logger._server_level, - 'Generating certificates', pkey_file, cert_file) + logger.log( + level, 'Generating certificates {} {}'.format( + pkey_file, cert_file)) cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 1024) careq = _create_cert_request(cakey, CN='Certificate Authority') cacert = _create_certificate( |