#!/usr/bin/env python3
# encoding=UTF-8

# Copyright © 2008-2016
#   Piotr Lewandowski <piotr.lewandowski@gmail.com>,
#   Jakub Wilk <jwilk@jwilk.net>.
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License, version 2, as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

from __future__ import print_function

# Module metadata; __version__ is what the --version option prints.
__author__ = ('Jakub Wilk', 'Piotr Lewandowski')
__version__ = '1.6'

from nntplib import NNTP, NNTP_PORT, NNTPTemporaryError, NNTPError
import argparse
import email
import email.generator
import errno
import functools
import io
import logging
import logging.handlers
import mailbox
import nntplib
import os
import os.path
import signal
import socket
import sys

import plugins
import utils

# nntplib limits line length to 2048 bytes to prevent denial of service
# (CVE-2013-1752). But the line length limit doesn't buy us much, because
# sinntp loads the whole message into memory anyway. So let's lift the limit to
# 1 megabyte.
#
# NOTE: _MAXLINE is a private (underscore-prefixed) attribute of nntplib, and
# this assignment replaces it module-wide, i.e. for every user of nntplib in
# this process.
#
# References:
# * https://github.com/jwilk/sinntp/issues/9
# * https://bugs.python.org/issue16040
nntplib._MAXLINE = 1 << 20

class Config(object):
    '''
    Persistent per-server integer store; each key lives in its own file.

    The storage directory is the first of these that applies:
    * $SINNTP_HOME, if set and non-empty (deprecated);
    * ~/.sinntp/, if it exists (deprecated);
    * whatever utils.xdg.save_data_path('sinntp') returns.
    '''

    def __init__(self, hostname, port):
        self._hostname = hostname
        self._port = port
        self._root = os.getenv('SINNTP_HOME')
        if self._root:
            logging.warning('$SINNTP_HOME is deprecated in favor of $XDG_DATA_HOME. See the NEWS file for details.')
            return
        self._root = os.path.expanduser('~/.sinntp/')
        if os.path.exists(self._root):
            logging.warning('$HOME/.sinntp/ is deprecated in favor of $XDG_DATA_HOME. See the NEWS file for details.')
            return
        self._root = utils.xdg.save_data_path('sinntp')

    def _get_file_name(self, name):
        '''
        Return the path of the file that stores the value for `name`.

        The base name is '<name>@<hostname>', with '_<port>' appended when
        the port is not the default NNTP port.
        '''
        base_name = '%s@%s' % (name, self._hostname)
        if self._port != NNTP_PORT:
            base_name += '_%d' % self._port
        return os.path.join(self._root, base_name)

    def __getitem__(self, name):
        '''
        Return the integer stored for `name`, or 0 if its file does not exist.

        Other I/O errors propagate; unparsable contents raise ValueError.
        '''
        try:
            file = open(self._get_file_name(name), 'rt', encoding='ASCII')
        except IOError as ex:
            if ex.errno == errno.ENOENT:
                return 0
            raise
        with file:
            return int(file.read())

    def __setitem__(self, name, value):
        '''
        Store `value` for `name` by writing a temporary file and renaming it
        over the real one.
        '''
        path = self._get_file_name(name)
        tmp_path = path + '_tmp'
        with open(tmp_path, 'wt', encoding='ASCII') as file:
            file.write(str(value))
            # os.fsync() only syncs what the OS already has; push Python's
            # own buffer down first, otherwise nothing may be synced at all.
            file.flush()
            os.fsync(file.fileno())
        os.rename(tmp_path, path)

class Command(object):

    # Common base of all sub-commands: parses the shared command-line options
    # and instantiates the requested plugins. The usage text comes from the
    # docstring of the concrete class, so every instantiated subclass must
    # define one.

    def get_option_parser(self):
        '''
        Build the argparse parser holding the options shared by all commands.
        '''
        o = argparse.ArgumentParser(usage=self.__doc__.rstrip())
        o.add_argument('--version', action='version', version='%(prog)s ' + __version__, help='show program\'s version number and exit')
        o.add_argument('-s', '--syslog', dest='syslog', action='store_true', help='use syslog for logging')
        # -v and -q share one destination. argparse would default it to False
        # (the store_true default), which setup_logging() maps to ERROR; an
        # explicit None keeps "neither flag given" distinguishable.
        o.add_argument('-v', '--verbose', dest='verbose', action='store_true', default=None, help='be more verbose')
        o.add_argument('-q', '--quiet', dest='verbose', action='store_false', default=None, help='be less verbose')
        o.add_argument('-p', '--plugin', dest='plugins', action='append', metavar='PLUGIN', help='use plugin')
        o.add_argument('-S', '--server', metavar='HOST[:PORT]', dest='server', action='store', help='specify server address')
        o.add_argument('-U', '--username', dest='username', action='store', help='specify username')
        o.add_argument('-P', '--password', dest='password', action='store', help='specify password')
        o.add_argument('--no-netrc', dest='netrc', action='store_false', help='ignore credentials in ~/.netrc')
        o.add_argument('-t', '--timeout', dest='timeout', action='store', type=int, help='specify connection timeout')
        return o

    def __init__(self, argv):
        '''
        Parse `argv`; the first argument the parser does not consume is the
        command name, the remaining ones are left in self.args.

        Each --plugin value has the form NAME[:ARG|:KEY=VALUE]...; NAME is
        looked up in the `plugins` module and bound to the given arguments.
        Invalid input is reported through the parser's error(), which exits.
        '''
        oparser = self.get_option_parser()
        self.options, self.args = oparser.parse_known_args(argv)
        if not self.args:
            oparser.error('A command is expected')
        self.command_name = self.args.pop(0)
        self.plugins = []
        for spec in self.options.plugins or ():
            name, *params = spec.split(':')
            factory = vars(plugins).get(name)
            if not callable(factory):
                # Report a usage error instead of dying with a traceback.
                oparser.error('Unknown plugin: %r' % name)
            args = [param for param in params if '=' not in param]
            kwargs = dict(param.split('=', 1) for param in params if '=' in param)
            self.plugins.append(functools.partial(factory, *args, **kwargs))
        self.oparser = oparser

    def __call__(self):
        # Does nothing here; the concrete commands in this file override it.
        pass

class Pull(Command):
    '''
    nntp-pull [options] groupname[>filename] [groupname[>filename] ...]
    '''

    # The class docstring above is used verbatim as the argparse usage text.

    def get_option_parser(self):
        # Common options plus the pull-specific ones.
        parser = Command.get_option_parser(self)
        parser.add_argument('--limit', metavar='N', dest='limit', action='store', type=int, help='pull at most N messages')
        parser.add_argument('--reget', dest='reget', action='store_true', help='start from the first available message')
        return parser

    def __init__(self, argv):
        Command.__init__(self, argv)
        groups = self.args
        del self.args
        if not groups:
            self.oparser.error('At least one group name is required')
        self.groups = groups

    def fetch(self, connection, group_name, start):
        # Generator: select `group_name` on the server and yield
        # (article number, article lines) pairs, beginning with the first
        # existing article numbered at least `start`.
        logging.info('Looking for group %r.', group_name)
        _, count, first, last, _ = connection.group(group_name)
        count, first, last = int(count), int(first), int(last)
        current = max(start, first)
        pending = max(last - current + 1, 0)
        logging.info('%(count)s message%(plural)s to download.' % dict(
            count = pending or 'No',
            plural = '' if pending == 1 else 's'
        ))
        # Skip over article numbers the server does not have; STAT also
        # makes the first existing one the current article.
        while True:
            if current > last:
                return
            try:
                connection.stat(str(current))
            except NNTPTemporaryError:
                current += 1
            else:
                break
        article_no = str(current)
        while True:
            _, article = connection.article(article_no)
            _, message_id, lines = article
            logging.debug('Reading message %s.', message_id)
            yield int(article_no), lines
            try:
                _, article_no, _ = connection.next()
            except NNTPTemporaryError as ex:
                # A 421 reply to NEXT ends the iteration; anything else is
                # a real error.
                if not ex.response.startswith('421'):
                    raise
                return

    def __call__(self, connection):
        # For every requested group, append the new messages to its mbox
        # file and remember in Config where the next run should resume.
        config = Config(connection.host, connection.port)
        for spec in self.groups:
            group, separator, mbox_name = spec.partition('>')
            if not separator:
                # No explicit file name: the mbox is named after the group.
                mbox_name = group
            mbox = atime = mode = last = None
            start = 0 if self.options.reget else config[group]
            if self.options.limit is not None:
                _, _, _, end, _ = connection.group(group)
                start = max(start, int(end) - self.options.limit + 1)
            try:
                for number, lines in self.fetch(connection, group, start):
                    if mbox is None:
                        # Open the mailbox lazily, only once there is
                        # something to store.
                        mbox = mailbox.mbox(mbox_name, create=True)
                        mbox.lock()
                        mbox_stat = os.stat(mbox_name)
                        atime, mode = mbox_stat.st_atime, mbox_stat.st_mode
                    message = email.message_from_bytes(utils.join_lines(lines))
                    for plugin in self.plugins:
                        message = plugin(message=message)
                        if message is None:
                            # A plugin dropped the message.
                            break
                    else:
                        mbox.add(message)
                    last = number
            finally:
                if mbox is not None:
                    mbox.close()
                    if atime is not None:
                        # Put back the access time recorded when the mbox
                        # was opened, keeping the new modification time.
                        os.utime(mbox_name, (atime, os.stat(mbox_name).st_mtime))
                    if mode is not None:
                        os.chmod(mbox_name, mode)
                if last is not None:
                    config[group] = last + 1

class Push(Command):
    '''
    nntp-push [options] [newsgroups...]
    '''

    # The class docstring above is used verbatim as the argparse usage text.

    def __call__(self, connection):
        # Read one message from standard input, run it through the plugins
        # and post it. Newsgroups given on the command line are used only if
        # the message has no Newsgroups header of its own.
        message = email.message_from_binary_file(sys.stdin.buffer)
        if 'Newsgroups' not in message and self.args:
            message['Newsgroups'] = ','.join(self.args)
        for plugin in self.plugins:
            message = plugin(message=message)
            if message is None:
                # Same convention as in Pull: a plugin returning None drops
                # the message, so there is nothing left to post.
                logging.warning('Message dropped by plugin; nothing was posted.')
                return
        buffer = io.BytesIO()
        # mangle_from_=False: do not escape body lines starting with "From ".
        generator = email.generator.BytesGenerator(buffer, mangle_from_=False)
        generator.flatten(message)
        buffer.seek(0)
        connection.post(buffer)

class Get(Command):
    '''
    nntp-get message-id
    '''

    # The class docstring above is used verbatim as the argparse usage text.

    def __init__(self, argv):
        Command.__init__(self, argv)
        if len(self.args) != 1:
            self.oparser.error('A single message-id is required')
        message_id = self.args[0]
        if '@' not in message_id:
            self.oparser.error('Message-id is malformed (\'@\' is missing)')
        # Add the angle brackets unless the argument already has both.
        if not (message_id.startswith('<') and message_id.endswith('>')):
            message_id = '<%s>' % message_id
        self.message_id = message_id
        del self.args

    def __call__(self, connection):
        # Fetch the article and write it to standard output as raw bytes.
        logging.debug('Reading message %s.', self.message_id)
        _, article = connection.article(self.message_id)
        _, _, lines = article
        sys.stdout.buffer.write(utils.join_lines(lines))

class List(Command):
    '''
    nntp-list
    '''

    # The class docstring above is used verbatim as the argparse usage text.

    def __init__(self, argv):
        Command.__init__(self, argv)
        if self.args:
            self.oparser.error('This command takes no arguments')

    def __call__(self, connection):
        # Print the name of every group the server lists, one per line.
        out = sys.stdout.buffer
        _, groups = connection.list()
        for name, _, _, _ in groups:
            # RFC 3977 §10.2 says that group names SHOULD be restricted to
            # US-ASCII and that 8-bit encodings SHOULD NOT be used.
            # Non-UTF-8 names do occur in practice, though, so encode with
            # surrogateescape and write bytes rather than fail on them.
            out.write(name.encode('UTF-8', 'surrogateescape') + b'\n')

class BaseCommand(Command):

    # Front-end parser: it only works out which sub-command was requested.

    # Sub-command name -> class implementing it.
    COMMANDS = {
        'pull': Pull,
        'push': Push,
        'get': Get,
        'list': List,
    }

    # Usage text: the usage lines of all sub-commands, one after another.
    __doc__ = ''.join(
        [command_class.__doc__.rstrip() for command_class in COMMANDS.values()]
    )

    def get_option_parser(self):
        return Command.get_option_parser(self)

    def __init__(self, argv):
        Command.__init__(self, argv)
        if self.command_name not in self.COMMANDS:
            self.oparser.error('Unknown command: %r.' % self.command_name)
        self.command_class = self.COMMANDS[self.command_name]
        # Rename the program after the chosen sub-command (nntp-pull, ...).
        sys.argv[0] = 'nntp-%s' % self.command_name

def setup_logging(syslog, verbose):
    '''
    Attach one handler to the root logger and set the root log level.

    syslog: if true, log to the syslog socket at /dev/log (facility "news"),
    prefixing each message with the program name and PID; otherwise use a
    default StreamHandler.
    verbose: True selects DEBUG, None selects INFO, False selects ERROR.
    '''
    message_format = '%(message)s'
    if syslog:
        handler = logging.handlers.SysLogHandler(
            '/dev/log',
            facility=logging.handlers.SysLogHandler.LOG_NEWS,
        )
        message_format = sys.argv[0] + '[%(process)d]: ' + message_format
    else:
        handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(message_format, None))
    levels = {
        True: logging.DEBUG,
        None: logging.INFO,
        False: logging.ERROR,
    }
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(levels.get(verbose))

class TerminatedBySignal(Exception):
    '''
    Raised by the handlers that setup_signal_handler() installs; the
    exception argument is the name of the signal.
    '''
    pass

def setup_signal_handler(signame):
    '''
    Make the signal named `signame` (e.g. 'SIGTERM') raise
    TerminatedBySignal(signame) when it is delivered.
    '''
    def raise_terminated(signo, frame):
        raise TerminatedBySignal(signame)
    signal_number = getattr(signal, signame)
    signal.signal(signal_number, raise_terminated)

def get_default_nntp_server():
    '''
    Return the default server address, or None if none is configured.

    A non-empty $NNTPSERVER wins; otherwise the stripped first line of
    /etc/news/server is returned. A missing file yields None; any other
    error while opening it propagates.
    '''
    from_environment = os.getenv('NNTPSERVER')
    if from_environment:
        return from_environment
    try:
        server_file = open('/etc/news/server', encoding='ASCII')
    except IOError as exc:
        if exc.errno != errno.ENOENT:
            raise
        return None
    with server_file:
        first_line = server_file.readline()
    return first_line.strip()

if __name__ == '__main__':
    # Script entry point. Exit statuses used below:
    #   1 - unsupported Python version
    #   2 - no NNTP server specified
    #   3 - could not connect to the server
    #   4 - the server reported an NNTP error
    argv = sys.argv[1:]
    # First pass only identifies the sub-command; the second pass parses the
    # same arguments with that sub-command's own parser.
    base_command = BaseCommand(argv)
    command = base_command.command_class(argv)
    setup_signal_handler('SIGTERM')
    setup_logging(command.options.syslog, command.options.verbose)
    if sys.version_info < (3, 2):
        # E-mail and NNTP modules are broken in Python 3.1:
        # https://docs.python.org/dev/whatsnew/3.2.html#new-improved-and-deprecated-modules
        logging.error('Python >= 3.2 is required')
        # Stop here instead of failing later with an unrelated traceback.
        sys.exit(1)
    host = command.options.server or get_default_nntp_server()
    if host is None:
        sys.stderr.write('NNTP server is not specified.\n')
        sys.exit(2)
    host, port = utils.split_host(host, NNTP_PORT)
    socket.setdefaulttimeout(command.options.timeout)
    try:
        logging.info('Connecting to %s:%d...', host, port)
        connection = NNTP(host,
            port=port,
            user=command.options.username,
            password=command.options.password,
            readermode=True,
            usenetrc=command.options.netrc
        )
    except socket.error as e:
        # strerror is None for some socket errors (e.g. timeouts); fall back
        # to the exception text so the message never ends in "None".
        logging.error('Could not connect to %s:%d: %s', host, port, e.strerror or e)
        sys.exit(3)
    except NNTPError as e:
        logging.error('%s:%d returned an error: %s', host, port, e)
        sys.exit(4)
    try:
        command(connection)
    except NNTPError as e:
        logging.error('NNTP error: %s', e)
        sys.exit(4)
    connection.quit()
    logging.info('Connection to %s:%d closed.', host, port)

# vim:ts=4 sts=4 sw=4 et
