aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--libbe/util/wsgi.py176
1 files changed, 151 insertions, 25 deletions
diff --git a/libbe/util/wsgi.py b/libbe/util/wsgi.py
index fd219e1..eddf36f 100644
--- a/libbe/util/wsgi.py
+++ b/libbe/util/wsgi.py
@@ -27,8 +27,12 @@ See Also
import copy
import hashlib
import logging
+import logging.handlers
+import os
import os.path
import re
+import select
+import signal
import StringIO
import sys
import time
@@ -58,6 +62,7 @@ except ImportError:
import libbe.util.encoding
import libbe.command
import libbe.command.base
+import libbe.command.util
import libbe.storage
@@ -573,6 +578,12 @@ class ServerCommand (libbe.command.base.Command):
Use this as a base class to build commands that serve a web interface.
"""
+ _daemon_actions = ['start', 'stop']
+ _daemon_action_present_participle = {
+ 'start': 'starting',
+ 'stop': 'stopping',
+ }
+
def __init__(self, *args, **kwargs):
super(ServerCommand, self).__init__(*args, **kwargs)
self.options.extend([
@@ -584,6 +595,23 @@ class ServerCommand (libbe.command.base.Command):
help='Set host string (blank for localhost)',
arg=libbe.command.Argument(
name='host', metavar='HOST', default='localhost')),
+ libbe.command.Option(name='daemon',
+ help=('Start or stop a server daemon. Stopping requires '
+ 'a PID file'),
+ arg=libbe.command.Argument(
+ name='daemon', metavar='ACTION',
+ completion_callback=libbe.command.util.Completer(
+ self._daemon_actions))),
+ libbe.command.Option(name='pidfile', short_name='p',
+ help='Store the process id in the given path',
+ arg=libbe.command.Argument(
+ name='pidfile', metavar='FILE',
+ completion_callback=libbe.command.util.complete_path)),
+ libbe.command.Option(name='logfile',
+ help='Log to the given path (instead of stdout)',
+ arg=libbe.command.Argument(
+ name='logfile', metavar='FILE',
+ completion_callback=libbe.command.util.complete_path)),
libbe.command.Option(name='read-only', short_name='r',
help='Dissable operations that require writing'),
libbe.command.Option(name='notify', short_name='n',
@@ -604,7 +632,14 @@ class ServerCommand (libbe.command.base.Command):
])
def _run(self, **params):
- self._setup_logging()
+ if params['daemon'] not in self._daemon_actions + [None]:
+ raise libbe.command.UserError(
+ 'Invalid daemon action "{}".\nValid actions:\n {}'.format(
+ params['daemon'], self._daemon_actions))
+ self._setup_logging(params)
+ if params['daemon'] not in [None, 'start']:
+ self._manage_daemon(params)
+ return
storage = self._get_storage()
if params['read-only']:
writeable = storage.writeable
@@ -632,15 +667,22 @@ class ServerCommand (libbe.command.base.Command):
def _get_app(self, logger, storage, **kwargs):
raise NotImplementedError()
- def _setup_logging(self, log_level=logging.INFO):
- self.logger = logging.getLogger('be-{}'.format(self.name))
- self.log_level = logging.INFO
- console = logging.StreamHandler(self.stdout)
- console.setFormatter(logging.Formatter('%(message)s'))
- self.logger.addHandler(console)
+ def _setup_logging(self, params, log_level=logging.INFO):
+ self.logger = logging.getLogger('be.{}'.format(self.name))
+ self.log_level = log_level
+ if params['logfile']:
+ path = os.path.abspath(os.path.expanduser(
+ params['logfile']))
+ handler = logging.handlers.TimedRotatingFileHandler(
+ path, when='w6', interval=1, backupCount=4,
+ encoding=libbe.util.encoding.get_text_file_encoding())
+ else:
+ handler = logging.StreamHandler(self.stdout)
+ handler.setFormatter(logging.Formatter('%(message)s'))
+ self.logger.addHandler(handler)
self.logger.propagate = False
if log_level is not None:
- console.setLevel(log_level)
+ handler.setLevel(log_level)
self.logger.setLevel(log_level)
def _get_server(self, params, app):
@@ -648,12 +690,15 @@ class ServerCommand (libbe.command.base.Command):
'socket-name':params['host'],
'port':params['port'],
}
+ if params['ssl']:
+ details['protocol'] = 'HTTPS'
+ else:
+ details['protocol'] = 'HTTP'
app = BEExceptionApp(app, logger=self.logger)
app = HandlerErrorApp(app, logger=self.logger)
app = ExceptionApp(app, logger=self.logger)
- if params['ssl'] == True:
- details['protocol'] = 'HTTPS'
- if cherrypy == None:
+ if params['ssl']:
+ if cherrypy is None:
raise libbe.command.UserError(
'--ssl requires the cherrypy module')
server = cherrypy.wsgiserver.CherryPyWSGIServer(
@@ -661,8 +706,8 @@ class ServerCommand (libbe.command.base.Command):
#server.throw_errors = True
#server.show_tracebacks = True
private_key,certificate = _get_cert_filenames(
- 'be-server', logger=self.logger)
- if cherrypy.wsgiserver.ssl_builtin == None:
+ 'be-server', logger=self.logger, level=self.log_level)
+ if cherrypy.wsgiserver.ssl_builtin is None:
server.ssl_module = 'builtin'
server.ssl_private_key = private_key
server.ssl_certificate = certificate
@@ -671,27 +716,106 @@ class ServerCommand (libbe.command.base.Command):
cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
certificate=certificate, private_key=private_key))
else:
- details['protocol'] = 'HTTP'
server = wsgiref.simple_server.make_server(
params['host'], params['port'], app,
handler_class=SilentRequestHandler)
return (server, details)
+ def _daemonize(self, params):
+ signal.signal(signal.SIGTERM, self._sigterm)
+ self.logger.log(self.log_level, 'Daemonizing')
+ pid = os.fork()
+ if pid > 0:
+ os._exit(0)
+ os.setsid()
+ pid = os.fork()
+ if pid > 0:
+ os._exit(0)
+ self.logger.log(
+ self.log_level, 'Daemonized with PID {}'.format(os.getpid()))
+
+ def _get_pidfile(self, params):
+ params['pidfile'] = os.path.abspath(os.path.expanduser(
+ params['pidfile']))
+ self.logger.log(
+ self.log_level, 'Get PID file at {}'.format(params['pidfile']))
+ if os.path.exists(params['pidfile']):
+ raise libbe.command.UserError(
+ 'PID file {} already exists'.format(params['pidfile']))
+ pid = os.getpid()
+ with open(params['pidfile'], 'w') as f: # race between exist and open
+ f.write(str(os.getpid()))
+ self.logger.log(
+ self.log_level, 'Got PID file as {}'.format(pid))
+
def _start_server(self, params, server, details):
- self.logger.log(self.log_level,
+ if params['daemon']:
+ self._daemonize(params=params)
+ if params['pidfile']:
+ self._get_pidfile(params)
+ self.logger.log(
+ self.log_level,
('Serving {protocol} on {socket-name} port {port} ...\n'
'BE repository {repo}').format(**details))
- if params['ssl']:
+ params['server stopped'] = False
+ if isinstance(server, wsgiref.simple_server.WSGIServer):
+ try:
+ server.serve_forever()
+ except select.error as e:
+ if len(e.args) == 2 and e.args[1] == 'Interrupted system call':
+ pass
+ else:
+ raise
+ else: # CherryPy server
server.start()
- else:
- server.serve_forever()
def _stop_server(self, params, server):
+ if params['server stopped']:
+ return # already stopped, e.g. via _sigterm()
+ params['server stopped'] = True
self.logger.log(self.log_level, 'Closing server')
- if params['ssl'] == True:
+ if isinstance(server, wsgiref.simple_server.WSGIServer):
+ server.server_close()
+ else:
server.stop()
+ if params['pidfile']:
+ os.remove(params['pidfile'])
+
+ def _sigterm(self, signum, frame):
+ self.logger.log(self.log_level, 'Handling SIGTERM')
+ # extract params and server from the stack
+ f = frame
+ while f is not None and f.f_code.co_name != '_start_server':
+ f = f.f_back
+ if f is None:
+ self.logger.log(
+ self.log_level,
+ 'SIGTERM from outside _start_server(): {}'.format(
+ frame.f_code))
+ return # where did this signal come from?
+ params = f.f_locals['params']
+ server = f.f_locals['server']
+ self._stop_server(params=params, server=server)
+
+ def _manage_daemon(self, params):
+ "Daemon management (any action besides 'start')"
+ if not params['pidfile']:
+ raise libbe.command.UserError(
+ 'daemon management requires --pidfile')
+ try:
+ with open(params['pidfile'], 'r') as f:
+ pid = f.read().strip()
+ except IOError as e:
+ raise libbe.command.UserError(
+ 'could not find PID file: {}'.format(e))
+ pid = int(pid)
+ pp = self._daemon_action_present_participle[params['daemon']].title()
+ self.logger.log(
+ self.log_level, '{} daemon running on process {}'.format(pp, pid))
+ if params['daemon'] == 'stop':
+ os.kill(pid, signal.SIGTERM)
else:
- server.server_close()
+ raise NotImplementedError(params['daemon'])
def _long_help(self):
raise NotImplementedError()
@@ -908,7 +1032,8 @@ if libbe.TESTING:
# The following certificate-creation code is adapted from pyOpenSSL's
# examples.
-def _get_cert_filenames(server_name, autogenerate=True, logger=None):
+def _get_cert_filenames(server_name, autogenerate=True, logger=None,
+ level=None):
"""
Generate private key and certification filenames.
get_cert_filenames(server_name) -> (pkey_filename, cert_filename)
@@ -918,7 +1043,7 @@ def _get_cert_filenames(server_name, autogenerate=True, logger=None):
if autogenerate:
for file in [pkey_file, cert_file]:
if not os.path.exists(file):
- _make_certs(server_name, logger)
+ _make_certs(server_name, logger=logger, level=level)
return (pkey_file, cert_file)
def _create_key_pair(type, bits):
@@ -1007,7 +1132,7 @@ def _create_certificate(req, (issuerCert, issuerKey), serial,
cert.sign(issuerKey, digest)
return cert
-def _make_certs(server_name, logger=None) :
+def _make_certs(server_name, logger=None, level=None):
"""Generate private key and certification files.
`mk_certs(server_name) -> (pkey_filename, cert_filename)`
@@ -1018,8 +1143,9 @@ def _make_certs(server_name, logger=None) :
pkey_file,cert_file = _get_cert_filenames(
server_name, autogenerate=False)
if logger != None:
- logger.log(logger._server_level,
- 'Generating certificates', pkey_file, cert_file)
+ logger.log(
+ level, 'Generating certificates {} {}'.format(
+ pkey_file, cert_file))
cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 1024)
careq = _create_cert_request(cakey, CN='Certificate Authority')
cacert = _create_certificate(