# chgsupport - functionality to support cHg
#
# Copyright 2011 Yuya Nishihara <yuya@tcha.org>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2 or any later version.

"""functionality to support cHg

hg serve --cmdserver=unix
    communicate with cmdserver via unix domain socket

This extends channels and commands of cmdserver:

's'-channel
    handles util.system() request

'P'-channel
    handles getpass() request

'chdir' command
    change current directory

'getpager' command
    checks if pager is enabled and which pager should be executed

'getpid' command
    get process id of the server (worker, master)

'setchguiattr' command
    set initial value of ui attributes

'setenv' command
    replace os.environ completely
"""

import errno, os, re, signal, socket, struct, sys, tempfile
from mercurial import cmdutil, commands, commandserver, dispatch, encoding, \
                      error, extensions, i18n, scmutil, util
from mercurial.node import hex
from mercurial.i18n import _

# NOTE(review): presumably the Mercurial versions this extension has been
# tested with (the usual extension 'testedwith' attribute) -- confirm
testedwith = '3.0 3.1'

class channeledinput(commandserver.channeledinput):
    """Input channel which can also ask the client for a password"""

    def readpass(self):
        """Read a reply requested through the 'P' channel

        Trailing newline characters are stripped from the reply.
        """
        reply = self._read(self.maxchunksize, 'P')
        return reply.rstrip('\n')

class channeledoutput(commandserver.channeledoutput):
    """Output channel which forwards flush() requests to the client"""

    def flush(self):
        """Send a flush() request, then flush the underlying stream"""
        # a channel header announcing 0-length data means "flush"
        header = struct.pack('>cI', self.channel, 0)
        self.out.write(header)
        self.out.flush()

class channeledsystem(object):
    """Ask the client to execute system() on behalf of the server

    The request is written to the output stream in the following format:

    channel (1 byte),
    command length (unsigned int),
    cmd

    and the reply is read from the input stream as:

    exitcode length (unsigned int),
    exitcode (int)
    """

    def __init__(self, in_, out, channel):
        self.channel = channel
        self.out = out
        self.in_ = in_

    # copied from mercurial/util.py
    def __call__(self, cmd, environ={}, cwd=None, onerr=None, errprefix=None,
                 out=None):
        """Run cmd at the client side and return its exit code

        If the exit code is non-zero and onerr is given, the failure is
        reported through onerr.warn() when onerr has such a method; otherwise
        onerr is called with the message and the result is raised.

        A reply whose announced length is not 4 is treated as exit code 255.
        """
        # XXX no support for environ, cwd and out

        quoted = util.quotecommand(cmd)

        # request: channel header followed by the command line
        self.out.write(struct.pack('>cI', self.channel, len(quoted)))
        self.out.write(quoted)
        self.out.flush()

        # reply: length-prefixed exit code
        (replylen,) = struct.unpack('>I', self.in_.read(4))
        rc = 255
        if replylen == 4:
            (rc,) = struct.unpack('>i', self.in_.read(4))

        if not (rc and onerr):
            return rc

        progname = os.path.basename(cmd.split(None, 1)[0])
        errmsg = '%s %s' % (progname, util.explainexit(rc)[0])
        if errprefix:
            errmsg = '%s: %s' % (errprefix, errmsg)
        try:
            onerr.warn(errmsg + '\n')
        except AttributeError:
            # onerr is not a ui-like object; treat it as an exception class
            raise onerr(errmsg)
        return rc

# Table of (environment variable names, modules) pairs.  _listmodstoreload()
# reports the modules of every row in which at least one of the listed
# variables differs between os.environ and the new environment.
_envmodstoreload = [
    # recreate gettext 't' by reloading i18n
    ('LANG LANGUAGE LC_MESSAGES'.split(), [encoding, i18n]),
    ('HGENCODING HGENCODINGMODE HGENCODINGAMBIGUOUS'.split(), [encoding]),
    ]

def _listmodstoreload(newenv):
    """Return the set of modules to reload if os.environ becomes newenv

    A module is included when any of the environment variables associated
    with it in _envmodstoreload has a different value in newenv than in the
    current os.environ (a missing variable counts as None).
    """
    stale = set()
    for names, modules in _envmodstoreload:
        changed = util.any(newenv.get(n) != os.environ.get(n) for n in names)
        if changed:
            stale.update(modules)
    return stale

def _fixdefaultencoding():
    """Copy the current encoding settings into the global option defaults

    Rewrites, in place, the default value (third field) of the 'encoding'
    and 'encodingmode' entries of commands.globalopts.
    """
    current = {'encoding': encoding.encoding,
               'encodingmode': encoding.encodingmode}
    for idx, opt in enumerate(commands.globalopts):
        value = current.get(opt[1])
        if value is None:
            continue
        # options are tuples; rebuild the entry with the new default
        commands.globalopts[idx] = opt[:2] + (value,) + opt[3:]

# matches a '$' followed by letters/underscores, i.e. what looks like a
# reference to an environment variable; used by _clearenvaliases()
_envvarre = re.compile(r'\$[a-zA-Z_]+')

def _clearenvaliases(cmdtable):
    """Drop command aliases whose definition references env vars

    Such aliases are stale once the environment changes; variable expansion
    is done at dispatch.addaliases().  Shell aliases (definition starting
    with '!') are kept.
    """
    stale = []
    for name, entry in cmdtable.items():
        alias = entry[0]
        if not isinstance(alias, dispatch.cmdalias):
            continue
        definition = alias.definition
        if definition.startswith('!'):  # shell alias
            continue
        if _envvarre.search(definition):
            stale.append(name)

    for name in stale:
        del cmdtable[name]

class chgcmdserver(commandserver.server):
    """Command server extended with the channels and commands used by cHg"""

    def __init__(self, sui, ui, repo, fin, fout, masterpid):
        super(chgcmdserver, self).__init__(ui, repo, mode='pipe')
        self._sui = sui  # ui for server output
        # channels multiplexed over the client connection
        self.cerr = channeledoutput(fout, fout, 'e')
        self.cout = channeledoutput(fout, fout, 'o')
        self.cin = channeledinput(fin, fout, 'I')
        self.cresult = commandserver.channeledoutput(fout, fout, 'r')
        self.csystem = channeledsystem(fin, fout, 's')
        self.client = fin
        self._masterpid = masterpid

    def _readpayload(self):
        """Read a length-prefixed argument sent by the client

        Returns None if the announced length is zero, otherwise the data
        read for that length.
        """
        (length,) = struct.unpack('>I', self._read(4))
        if not length:
            return None
        return self._read(length)

    def chdir(self):
        """Change current directory

        Note that the behavior of --cwd option is a bit different from this.
        It does not affect --config parameter.
        """
        path = self._readpayload()
        if path is None:
            return
        self._sui.debug('chdir to %r\n' % path)
        os.chdir(path)

    def _pagercmd(self, args):
        """Return the pager command to run for cmdargs, or None if disabled"""
        try:
            pagermod = extensions.find('pager')
        except KeyError:
            return None

        pagercmd = self.ui.config('pager', 'pager', os.environ.get('PAGER'))
        if not pagercmd:
            return None

        try:
            parsed = dispatch._parse(self.ui, args)
        except (error.AmbiguousCommand, error.CommandError,
                error.UnknownCommand):
            return None
        cmd, options = parsed[0], parsed[3]

        # duplicated from hgext/pager.py
        attend = self.ui.configlist('pager', 'attend', pagermod.attended)
        auto = options['pager'] == 'auto'
        always = util.parsebool(options['pager'])
        if always:
            return pagercmd
        if not auto:
            return None
        if cmd in attend:
            return pagercmd
        ignored = cmd in self.ui.configlist('pager', 'ignore')
        if ignored or attend:
            return None
        return pagercmd

    def getpager(self):
        """Read cmdargs and write pager command to r-channel if enabled

        If pager isn't enabled, this writes '\\0' because channeledoutput
        does not allow to write empty data.
        """
        packed = self._readpayload()
        if packed is None:
            args = []
        else:
            args = packed.split('\0')

        pagercmd = self._pagercmd(args)
        if pagercmd:
            self.cresult.write(pagercmd)
        else:
            self.cresult.write('\0')

    def getpid(self):
        """Write pid of the server (worker, master)"""
        pids = struct.pack('>ii', os.getpid(), self._masterpid)
        self.cresult.write(pids)

    def runcommand(self):
        """Run a command with stderr redirected to the client's 'e' channel"""
        # reset time-stamp so that "progress.delay" can take effect
        progbar = getattr(self.ui, '_progbar', None)
        if progbar:
            self._sui.debug('reset progbar\n')
            progbar.resetstate()
        # swap stderr so that progress output can be redirected to client
        savederr = sys.stderr
        sys.stderr = self.cerr
        try:
            super(chgcmdserver, self).runcommand()
        finally:
            sys.stderr = savederr

    def setchguiattr(self):
        """Set ui attributes if not configured

        The request is a list of 'section.name: value' lines.
        """
        data = self._readpayload()
        if data is None:
            return
        for line in data.splitlines():
            key, value = line.split(': ', 1)
            section, name = key.split('.', 1)
            if self.ui.config(section, name) is not None:
                continue  # explicitly configured; leave it alone
            self._sui.debug('set ui %r: %r\n' % (key, value))
            self.ui.setconfig(section, name, value)

        # reload config so that ui.plain() takes effect
        for rcfile in scmutil.rcpath():
            self.ui.readconfig(rcfile, trust=True)

        # skip initial uisetup() by 'mode=chgauto'
        if self.ui.config('color', 'mode') == 'chgauto':
            self.ui.setconfig('color', 'mode', 'auto')

    def setenv(self):
        """Clear and update os.environ

        The request is a '\\0'-separated list of 'NAME=value' entries.

        Note that not all variables can make an effect on the running process.
        """
        data = self._readpayload()
        if data is None:
            return
        try:
            newenv = dict([entry.split('=', 1) for entry in data.split('\0')])
        except ValueError:
            raise ValueError('unexpected value in setenv request')

        # must be computed before os.environ is replaced
        modstoreload = _listmodstoreload(newenv)

        if self._sui.debugflag:
            allkeys = set(os.environ.keys())
            allkeys.update(newenv.keys())
            for k in sorted(allkeys):
                ov, nv = os.environ.get(k), newenv.get(k)
                if ov != nv:
                    self._sui.debug('change env %r: %r -> %r\n' % (k, ov, nv))
        os.environ.clear()
        os.environ.update(newenv)

        for mod in modstoreload:
            self._sui.debug('reload %s module\n' % mod.__name__)
            reload(mod)
        if encoding in modstoreload:
            _fixdefaultencoding()

        _clearenvaliases(commands.table)

    # commands added to (or overriding) those of the base command server
    capabilities = commandserver.server.capabilities.copy()
    capabilities.update(chdir=chdir,
                        getpager=getpager,
                        getpid=getpid,
                        runcommand=runcommand,
                        setchguiattr=setchguiattr,
                        setenv=setenv)

def _wrapui(ui):
    """Replace the class of ui, in place, by a subclass adapted to cHg

    The subclass delegates the editor to the client through csystem, reads
    passwords through the input channel, and takes the terminal width and
    the "plain" settings from 'ui._termwidth', 'ui._plain' and
    'ui._plainexcept' config values.
    """
    class chgui(ui.__class__):
        csystem = None  # should be assigned later

        def copy(self):
            dup = super(chgui, self).copy()
            dup.csystem = self.csystem
            # progbar is a singleton; we want to make it see the last config
            progbar = getattr(dup, '_progbar', None)
            if progbar:
                progbar.ui = dup
            return dup

        def termwidth(self):
            width = self.configint('ui', '_termwidth')
            if width and width > 0:
                return width
            return super(chgui, self).termwidth()

        # copied from mercurial/ui.py
        def edit(self, text, user, extra={}, editform=None):
            fd, name = tempfile.mkstemp(prefix="hg-editor-", suffix=".txt",
                                        text=True)
            try:
                fp = os.fdopen(fd, "w")
                fp.write(text)
                fp.close()

                environ = {'HGUSER': user}
                if 'transplant_source' in extra:
                    environ['HGREVISION'] = hex(extra['transplant_source'])
                for label in ('source', 'rebase_source'):
                    if label in extra:
                        environ['HGREVISION'] = extra[label]
                        break
                if editform:
                    environ['HGEDITFORM'] = editform

                editor = self.geteditor()

                # the editor is run by the client, not by this process
                self.csystem('%s "%s"' % (editor, name),
                             environ=environ,
                             onerr=util.Abort, errprefix=_("edit failed"),
                             out=self.fout)

                fp = open(name)
                edited = fp.read()
                fp.close()
            finally:
                os.unlink(name)

            return edited

        # copied from mercurial/ui.py
        def getpass(self, prompt=None, default=None):
            if not self.interactive():
                return default
            try:
                self.fout.write(prompt or _('password: '))
                self.fout.flush()
                return self.fin.readpass()
            except EOFError:
                raise util.Abort(_('response expected'))

        # copied from mercurial/ui.py
        def plain(self, feature=None):
            isplain = self.configbool('ui', '_plain')
            exceptions = self.configlist('ui', '_plainexcept')
            if not (isplain or exceptions):
                return False
            if feature and exceptions:
                return feature not in exceptions
            return True

    ui.__class__ = chgui

class service(object):
    """Forking server for handling connection from cHg client

    The master process listens on a unix domain socket and forks one worker
    process per accepted connection.  init() and run() are passed to
    cmdutil.service() as initfn and runfn.
    """
    def __init__(self, ui, sockpath):
        self._ui = ui
        self._sockpath = sockpath

    def init(self):
        """Bind the socket and install the signal handlers of the master"""
        self._sock = socket.socket(socket.AF_UNIX)
        self._sock.bind(self._sockpath)
        self._masterpid = os.getpid()
        self._workerpids = []
        util.setsignalhandler()
        signal.signal(signal.SIGCHLD, self._waitworker)

    def run(self):
        """Serve connections until interrupted, then clean up

        On SignalInterrupt/KeyboardInterrupt the workers are terminated.  The
        socket is always closed and its path unlinked on the way out.
        """
        self._sock.listen(1)
        self._ui.note(_('listening at %s (pid=%d)\n')
                      % (self._sockpath, self._masterpid))

        try:
            try:
                self._serve()
            except (error.SignalInterrupt, KeyboardInterrupt):
                self._killworkers()
        finally:
            self._sock.close()
            os.unlink(self._sockpath)

    def _serve(self):
        """Accept connections forever, forking a worker for each of them"""
        while True:
            try:
                conn, _addr = self._sock.accept()
            except socket.error, inst:
                # accept() may be interrupted by SIGCHLD
                if inst.errno != errno.EINTR:
                    raise
                continue
            pid = os.fork()
            if pid == 0:
                self._runworker(conn)  # never returns
            self._ui.debug('forked worker process (pid=%d)\n' % pid)
            self._workerpids.append(pid)
            conn.close()  # release handle in parent process

    def _runworker(self, conn):
        """Serve conn in a freshly forked worker process; never returns

        Whatever happens, the worker leaves through os._exit() so that it can
        neither fall back into the master's accept loop nor run the master's
        cleanup in run(), which would unlink the listening socket.
        """
        try:
            # the handler inherited from the master would reap children
            # spawned by this worker
            signal.signal(signal.SIGCHLD, signal.SIG_DFL)
            self._handlerequest(conn)
            conn.close()
            os._exit(0)
        except KeyboardInterrupt:
            os._exit(255)
        except:  # deliberately broad; see the docstring
            try:
                import traceback
                self._ui.warn(traceback.format_exc())
            finally:
                os._exit(255)

    def _waitworker(self, signum, frame):
        """SIGCHLD handler: reap exited workers without blocking"""
        while self._workerpids:
            try:
                pid, _status = os.waitpid(-1, os.WNOHANG)
            except OSError, inst:
                if inst.errno == errno.EINTR:
                    continue
                if inst.errno != errno.ECHILD:
                    raise
                pid = 0
            if pid == 0:
                return
            self._ui.debug('worker process exited (pid=%d)\n' % pid)
            # the pid may be unknown if the worker exited before _serve()
            # could register it; raising from a signal handler would abort
            # the master
            if pid in self._workerpids:
                self._workerpids.remove(pid)

    def _killworkers(self):
        """Send SIGTERM to all workers and wait for them to exit"""
        # iterate over a copy; the SIGCHLD handler may shrink the list
        for pid in list(self._workerpids):
            self._ui.debug('terminate worker process (pid=%d)\n' % pid)
            try:
                os.kill(pid, signal.SIGTERM)
            except OSError, err:
                # already gone (and reaped); nothing left to terminate
                if err.errno != errno.ESRCH:
                    raise

        while self._workerpids:
            try:
                pid, _status = os.wait()
            except OSError, err:
                if err.errno == errno.EINTR:
                    # interrupted by SIGCHLD; retry
                    continue
                if err.errno != errno.ECHILD:
                    raise
                self._ui.debug('got ECHILD while terminating worker '
                               'processes\n')
                return
            if pid in self._workerpids:
                self._workerpids.remove(pid)
        self._ui.debug('terminated all worker processes\n')

    def _handlerequest(self, conn):
        """Run a command server over conn, then close its file objects"""
        self._ui.note('accepted connection (pid=%d)\n' % os.getpid())
        fin = conn.makefile('rb')
        fout = conn.makefile('wb')
        try:
            self._runcmdserver(fin, fout)
        finally:
            fin.close()
            # implicit flush() may cause another EPIPE
            try:
                fout.close()
            except IOError, err:
                if err.errno != errno.EPIPE:
                    raise

    def _runcmdserver(self, fin, fout):
        """Serve requests read from fin, writing replies to fout

        A broken pipe ends the session quietly.  The current directory is
        restored afterwards.
        """
        ui = self._ui.copy()
        # don't retain command-line opts of the service
        for opt in ('verbose', 'debug', 'quiet', 'traceback'):
            ui.setconfig('ui', opt, None)
        ui.setconfig('progress', 'assume-tty', None)
        _wrapui(ui)
        sv = chgcmdserver(self._ui, ui, None, fin, fout,
                          masterpid=self._masterpid)
        ui.csystem = sv.csystem

        origcwd = os.getcwd()
        try:
            try:
                sv.serve()
                self._ui.debug('finished (pid=%d)\n' % os.getpid())
            except IOError, err:
                if err.errno == errno.EPIPE:
                    self._ui.debug('finished by broken pipe (pid=%d)\n'
                                   % os.getpid())
                else:
                    raise
        finally:
            os.chdir(origcwd)

def _disablepager():
    """Turn _runpager() of the pager extension into a no-op, if loaded"""
    try:
        pagermod = extensions.find('pager')
    except KeyError:
        return

    def _nopager(orig, *args):
        # _runpager(p) or _runpager(ui, p)
        return None

    extensions.wrapfunction(pagermod, '_runpager', _nopager)

def _serve(orig, ui, repo, **opts):
    """Wrapper of the serve command which handles --cmdserver=unix

    Any other mode is passed through to the original command.
    """
    if opts['cmdserver'] != 'unix':
        return orig(ui, repo, **opts)

    _disablepager()

    # it might be confusing to reuse --port option
    sockpath = opts['port']
    if not sockpath:
        raise util.Abort(_('no socket path specified with --port'))

    svc = service(ui, sockpath)
    cmdutil.service(opts, initfn=svc.init, runfn=svc.run,
                    logfile=opts['errorlog'])

def uisetup(ui):
    """Wrap the 'serve' command with _serve()"""
    table = commands.table
    extensions.wrapcommand(table, 'serve', _serve)
