Cleaned up code

This commit is contained in:
Florian Rupp 2023-04-18 22:18:02 +02:00
parent d3eaaf3b9f
commit 03c25129c9

View file

@ -6,14 +6,14 @@ import utmp
# These defaults can be overwritten by command line arguments # These defaults can be overwritten by command line arguments
SERVER_HOST = '0.0.0.0' SERVER_HOST = '0.0.0.0'
SERVER_PORT = 9999 SERVER_PORT = 9999
FETCH_INTERVAL = 5 FETCH_INTERVAL = 15
class Session: class Session:
""" This class is used to create a Session object containing info on an SSH session, mainly for readability """ This class is used to create a Session object containing info on an SSH session, mainly for readability
Only the fields name, tty, from_, login are actually used for now """ Only the fields name, tty, from_, login are actually used for now """
def __init__(self, name, tty, from_, login, idle, jcpu, pcpu, what): def __init__(self, name : str, tty : str, from_ : str, login : str, idle=0, jcpu=0, pcpu=0, what=0):
self.name = name # Username that is logged in self.name = name # Username that is logged in
self.tty = tty # Which tty is used self.tty = tty # Which tty is used
self.from_ = from_ # remote IP address self.from_ = from_ # remote IP address
@ -32,56 +32,25 @@ class Session:
def __eq__(self, other): def __eq__(self, other):
return self.login == other.login and self.tty == other.tty and self.from_ == other.from_ return self.login == other.login and self.tty == other.tty and self.from_ == other.from_
def to_dict(self):
# maybe this will be used later
return {
'name': self.name,
'tty': self.tty,
'from_': self.from_,
'login': self.login,
'idle': self.idle,
'jcpu': self.jcpu,
'pcpu': self.pcpu,
'what': self.what
}
def to_list(self): def get_utmp_data() -> list[Session]:
return [self.name, self.tty, self.from_, self.login, self.idle, self.jcpu, self.pcpu, self.what]
def contains_user_list(user, other_user_list):
for other_user in other_user_list:
if are_equal(user, other_user):
return True
return False
def are_equal(user_list, other_user_list):
""" Two SSh sessions are equal if their name, tty, remote IP and login time are equal
The other fields change over time hence they are not used for comparison """
assert len(user_list) == len(other_user_list)
for i in range(4):
if user_list[i] != other_user_list[i]:
return False
return True
def get_utmp_data():
""" """
Returns a list of User Objects Returns a list of User Objects
The function uses the utmp library. The utmp file contains information about currently logged in users The function uses the utmp library. The utmp file contains information about ALL currently logged in users,
including local users (not SSH sessions). We filter out the local users by checking if the remote IP address
is empty.
""" """
users = [] users : list[Session] = []
with open('/var/run/utmp', 'rb') as f: with open('/var/run/utmp', 'rb') as fd:
buffer = f.read() buffer = fd.read()
for record in utmp.read(buffer): for record in utmp.read(buffer):
if record.type == utmp.UTmpRecordType.user_process: if record.type == utmp.UTmpRecordType.user_process and record.host != '':
users.append(Session(record.user, record.line, record.host, record.sec, 0, 0, 0, 0)) users.append(Session(record.user, record.line, record.host, record.sec))
return users return users
def parse_arguments(): def parse_arguments() -> None:
global FETCH_INTERVAL, SERVER_PORT, SERVER_HOST global FETCH_INTERVAL, SERVER_PORT, SERVER_HOST
@ -114,39 +83,38 @@ if __name__ == '__main__':
# Start up the server to expose the metrics. # Start up the server to expose the metrics.
prometheus_client.start_http_server(SERVER_PORT) prometheus_client.start_http_server(SERVER_PORT)
print("Started metrics server bound to {}:{}".format(SERVER_HOST, SERVER_PORT)) print("Started metrics server bound to {}:{}".format(SERVER_HOST, SERVER_PORT))
num_sessions = []
gauge_num_sessions = prometheus_client.Gauge( gauge_num_sessions = prometheus_client.Gauge(
'ssh_num_sessions', 'Number of SSH sessions', ['remote_ip']) 'ssh_num_sessions', 'Number of SSH sessions', ['remote_ip'])
# data = get_w_data()
data = get_utmp_data() session_data = get_utmp_data()
list_data = [user.to_list() for user in data] num_sessions = len(session_data)
# Initial metrics # Initial metrics
print("Connections at startup:") print("Active sessions at startup:")
for i in range(len(list_data)): for session in session_data:
gauge_num_sessions.labels(remote_ip=list_data[i][2]).inc() gauge_num_sessions.labels(remote_ip=session.from_).inc()
print("Initial connection: {}".format(list_data[i])) print("Initial connection: {}".format(session.from_))
# Generate some requests. # Generate some requests.
print("Looking for SSH connection changes at interval {}".format(FETCH_INTERVAL)) print("Looking for SSH connection changes at interval {}".format(FETCH_INTERVAL))
while True: while True:
list_old_data = list_data old_session_data = session_data
# data = get_w_data() old_num_sessions = len(old_session_data)
data = get_utmp_data()
list_data = [user.to_list() for user in data]
num_sessions = len(data)
for i in range(num_sessions): session_data = get_utmp_data()
num_sessions = len(session_data)
for maybe_new_session in session_data:
# Looking for newly found SSH sessions # Looking for newly found SSH sessions
if not contains_user_list(list_data[i], list_old_data): if not maybe_new_session in old_session_data:
print("Session connected: %s" % list_data[i]) print("Session connected: %s" % maybe_new_session.from_)
gauge_num_sessions.labels(remote_ip=list_data[i][2]).inc() gauge_num_sessions.labels(remote_ip=maybe_new_session.from_).inc()
for i in range(len(list_old_data)): for maybe_old_session in old_session_data:
# Looking for SSH sessions that no longer exist # Looking for SSH sessions that no longer exist
if not contains_user_list(list_old_data[i], list_data): if not maybe_old_session in session_data:
print("Session disconnected: %s" % list_old_data[i]) print("Session disconnected: %s" % maybe_old_session.from_)
gauge_num_sessions.labels(remote_ip=list_old_data[i][2]).dec() gauge_num_sessions.labels(remote_ip=maybe_old_session.from_).dec()
time.sleep(FETCH_INTERVAL) time.sleep(FETCH_INTERVAL)