diff --git a/setup.py b/setup.py index 562b643..0bf1581 100644 --- a/setup.py +++ b/setup.py @@ -36,4 +36,3 @@ setup( ], zip_safe=False, ) - diff --git a/snappass/main.py b/snappass/main.py index 9db002d..a7bee3e 100644 --- a/snappass/main.py +++ b/snappass/main.py @@ -3,7 +3,7 @@ import sys import uuid import redis -from redis.exceptions import ConnectionError +from redis.exceptions import ConnectionError, TimeoutError from flask import abort, Flask, render_template, request @@ -23,27 +23,31 @@ time_conversion = { 'hour': 3600 } -def check_redis_alive(force_exit=False): - try: - redis_client.ping() - except (ConnectionError, TimeoutError) as e: - if force_exit: + +def check_redis_alive(fn): + def inner(*args, **kwargs): + try: + redis_client.ping() + except (ConnectionError, TimeoutError) as e: print(e) - sys.exit(0) - else: - abort(500) + if fn.__name__ == "main": + sys.exit(0) + else: + return abort(500) + return fn(*args, **kwargs) + return inner +@check_redis_alive def set_password(password, ttl): - check_redis_alive() key = uuid.uuid4().hex redis_client.set(key, password) redis_client.expire(key, ttl) return key +@check_redis_alive def get_password(key): - check_redis_alive() password = redis_client.get(key) if password is not None: password = password.decode('utf-8') @@ -96,8 +100,8 @@ def show_password(password_key): return render_template('password.html', password=password) +@check_redis_alive def main(): - check_redis_alive(force_exit=True) app.run(host='0.0.0.0', debug=True)