From 55b95be58bbef0f4c5d496f010c09cf3d3169bf1 Mon Sep 17 00:00:00 2001 From: Joel Rosdahl Date: Sun, 4 Dec 2011 22:22:48 +0100 Subject: Optionally store persistent state (currently channel topic and key) --- miniircd | 123 ++++++++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 83 insertions(+), 40 deletions(-) diff --git a/miniircd b/miniircd index c579bf1..c740315 100755 --- a/miniircd +++ b/miniircd @@ -23,11 +23,12 @@ VERSION = "0.3" import os +import re import select import socket import string import sys -import re +import tempfile import time from datetime import datetime from optparse import OptionParser @@ -43,17 +44,60 @@ class Channel(object): self.server = server self.name = name self.members = set() - self.topic = "" - self.key = None + self._topic = "" + self._key = None + if self.server.statedir: + self._state_path = "%s/%s" % ( + self.server.statedir, + name.replace("_", "__").replace("/", "_")) + self._read_state() + else: + self._state_path = None def add_member(self, client): self.members.add(client) + def get_topic(self): + return self._topic + + def set_topic(self, value): + self._topic = value + self._write_state() + + topic = property(get_topic, set_topic) + + def get_key(self): + return self._key + + def set_key(self, value): + self._key = value + self._write_state() + + key = property(get_key, set_key) + def remove_client(self, client): self.members.discard(client) if not self.members: self.server.remove_channel(self) + def _read_state(self): + if not (self._state_path and os.path.exists(self._state_path)): + return + data = {} + exec(open(self._state_path), {}, data) + self._topic = data.get("topic", "") + self._key = data.get("key") + + def _write_state(self): + if not self._state_path: + return + (fd, path) = tempfile.mkstemp(dir=os.path.dirname(self._state_path)) + fp = os.fdopen(fd, "w") + fp.write("topic = %r\n" % self.topic) + fp.write("key = %r\n" % self.key) + fp.close() + os.rename(path, self._state_path) + class Client(object): __linesep_regexp = re.compile(r"\r?\n") @@ -195,21 +239,19 @@ class Client(object): keys = [] keys.extend((len(channelnames) - len(keys)) * [None]) for (i, channelname) in enumerate(channelnames): - if channelname in self.channels: + if irc_lower(channelname) in self.channels: continue if not valid_channel_re.match(channelname): self.reply_403(channelname) continue channel = server.get_channel(channelname) - if channel \ - and channel.key is not None \ - and channel.key != keys[i]: + if channel.key is not None and channel.key != keys[i]: self.reply( "475 %s %s :Cannot join channel (+k) - bad key" % (self.nickname, channelname)) continue - server.add_member_to_channel(self, channelname) - channel = server.get_channel(channelname) + channel.add_member(self) + self.channels[irc_lower(channelname)] = channel self.message_channel(channel, "JOIN", channelname, True) self.channel_log(channel, "joined", meta=True) if channel.topic: @@ -232,9 +274,8 @@ class Client(object): else: channels = [] for channelname in arguments[0].split(","): - channel = server.get_channel(channelname) - if channel: - channels.append(channel) + if server.has_channel(channelname): + channels.append(server.get_channel(channelname)) channels.sort(key=lambda x: x.name) for channel in channels: self.reply("322 %s %s %d :%s" @@ -247,8 +288,8 @@ class Client(object): self.reply_461("MODE") return targetname = arguments[0] - channel = server.get_channel(targetname) - if channel: + if server.has_channel(targetname): + channel = server.get_channel(targetname) if len(arguments) < 2: if channel.key: modes = "+k" @@ -340,15 +381,14 @@ class Client(object): if client: client.message(":%s %s %s :%s" % (self.prefix, command, targetname, message)) - else: + elif server.has_channel(targetname): channel = server.get_channel(targetname) - if channel: - self.message_channel( - channel, command, "%s :%s" % (channel.name, message)) - self.channel_log(channel, message) - else: - self.reply("401 %s %s :No such nick/channel" - % (self.nickname, targetname)) + self.message_channel( + channel, command, "%s :%s" % (channel.name, message)) + self.channel_log(channel, message) + else: + self.reply("401 %s %s :No such nick/channel" + % (self.nickname, targetname)) def part_handler(): if len(arguments) < 1: @@ -394,8 +434,8 @@ class Client(object): self.reply_461("TOPIC") return channelname = arguments[0] - if channelname in self.channels: - channel = server.get_channel(channelname) + channel = self.channels.get(irc_lower(channelname)) + if channel: if len(arguments) > 1: newtopic = arguments[1] channel.topic = newtopic @@ -427,8 +467,8 @@ class Client(object): if len(arguments) < 1: return targetname = arguments[0] - channel = server.get_channel(targetname) - if channel: + if server.has_channel(targetname): + channel = server.get_channel(targetname) for member in channel.members: self.reply("352 %s %s %s %s %s %s H :0 %s" % (self.nickname, targetname, member.user, @@ -582,12 +622,15 @@ class Server(object): self.verbose = options.verbose self.debug = options.debug self.logdir = options.logdir + self.statedir = options.statedir self.name = socket.getfqdn()[:63] # Server name limit from the RFC. self.channels = {} # irc_lower(Channel name) --> Channel instance. self.clients = {} # Socket --> Client instance. self.nicknames = {} # irc_lower(Nickname) --> Client instance. if self.logdir: create_directory(self.logdir) + if self.statedir: + create_directory(self.statedir) def daemonize(self): try: @@ -614,8 +657,16 @@ class Server(object): def get_client(self, nickname): return self.nicknames.get(irc_lower(nickname)) + def has_channel(self, name): + return irc_lower(name) in self.channels + def get_channel(self, channelname): - return self.channels.get(irc_lower(channelname)) + if irc_lower(channelname) in self.channels: + channel = self.channels[irc_lower(channelname)] + else: + channel = Channel(self, channelname) + self.channels[irc_lower(channelname)] = channel + return channel def get_motd_lines(self): if self.motdfile: @@ -644,13 +695,6 @@ class Server(object): del self.nicknames[irc_lower(oldnickname)] self.nicknames[irc_lower(client.nickname)] = client - def add_member_to_channel(self, client, channelname): - if irc_lower(channelname) in self.channels: - channel = self.channels[irc_lower(channelname)] - else: - channel = self.add_channel(channelname) - channel.add_member(client) - def remove_member_from_channel(self, client, channelname): if irc_lower(channelname) in self.channels: channel = self.channels[irc_lower(channelname)] @@ -666,11 +710,6 @@ class Server(object): del self.nicknames[irc_lower(client.nickname)] del self.clients[client.socket] - def add_channel(self, name): - channel = Channel(self, channelname) - self.channels[irc_lower(channelname)] = channel - return channel - def remove_channel(self, channel): del self.channels[irc_lower(channel.name)] @@ -751,7 +790,11 @@ def main(argv): "--ports", metavar="X", help="listen to ports X (a list separated by comma or whitespace);" - " default: 6667.") + " default: 6667") + op.add_option( + "--statedir", + metavar="X", + help="save persistent channel state (topic, key) in directory X") op.add_option( "--verbose", action="store_true", -- cgit v1.2.3