#!/usr/bin/python
#
#
# Copyright (C) 2009  Larne Pekowsky, Ping Wei
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


"""
This provides the means to answer several questions posed against either the
segment database or a collection of DMT XML files:

  * What DQ flags exist in the database? ligolw_segment_query --show-types
  * When was a given DQ flag defined? ligolw_segment_query --query-types 
  * When was a given flag active? ligolw_segment_query --query-segments
"""


from __future__ import print_function
from optparse import OptionParser
import sqlite3
import sys
import os
import operator
import pwd
import glob
import time
import socket
import tempfile
import urllib

import ligo.segments

from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils

from glue.ligolw.utils import ligolw_add
from glue.ligolw.utils import process
from glue.segmentdb import query_engine
from glue.segmentdb import segmentdb_utils

from glue.ligolw.utils import ligolw_sqlite
from glue.ligolw import dbtables

from glue import git_version
import six

# Name under which this program registers itself in the output document.
# Only a leading "./" is stripped; replacing every "./" would mangle paths
# such as "../bin/prog" into ".bin/prog".
PROGRAM_NAME = sys.argv[0]
if PROGRAM_NAME.startswith('./'):
    PROGRAM_NAME = PROGRAM_NAME[len('./'):]

PROGRAM_PID  = os.getpid()

try:
    # os.getlogin() reports the user of the controlling terminal and fails
    # when there is none (cron jobs, batch systems, ...).
    USER_NAME = os.getlogin()
except Exception:
    # Fall back to the password database entry of the current uid.  A bare
    # "except:" would also have swallowed KeyboardInterrupt and SystemExit.
    USER_NAME = pwd.getpwuid(os.getuid())[0]


# Module metadata.  The version and date strings are taken from glue's
# git_version module; __date__ is later passed as cvs_entry_time when the
# process is registered in the output document.
__author__  = "Larne Pekowsky <lppekows@physics.syr.edu>, Ping Wei <piwei@syr.edu>"
__version__ = "git id %s" % git_version.id
__date__ = git_version.date


#
# =============================================================================
#
#                                 Command Line
#
# =============================================================================
#


def _download_to_temp_dir(url):
    """
    Fetch url into a freshly created temporary directory and return the
    path of that directory.

    The file is named after the part of the URL following the last slash,
    or "dmt.xml" when the URL contains no slash or ends with one.
    """
    # urlopen moved between Python 2 and 3; import locally so the rest of
    # the module does not depend on which one is available.
    try:
        from urllib.request import urlopen
    except ImportError:
        from urllib import urlopen

    tmp_dir = tempfile.mkdtemp()

    _, slash, tail = url.rpartition('/')
    fname = tail if (slash and tail) else "dmt.xml"

    inurl = urlopen(url)
    try:
        # Copy the payload verbatim: binary mode keeps the bytes untouched
        # on both Python 2 and 3 and avoids inserting separators between
        # lines.
        with open(os.path.join(tmp_dir, fname), 'wb') as outfile:
            for chunk in inurl:
                outfile.write(chunk)
    finally:
        inurl.close()

    return tmp_dir


def parse_command_line():
    """
    Parse the command line and validate the combination of options.

    Returns a tuple (options, database_location, file_location).  Exactly
    one of database_location and file_location is set: the former when a
    segment database is to be queried, the latter when a directory of DMT
    XML files is to be read.

    Raises ValueError when the mode, the data location or a required
    argument is missing or ambiguous.
    """

    parser = OptionParser(
        version = "Name: %%prog\n%s" % git_version.verbose_msg,
        usage       = "%prog [ --version | --ping | --show-types | --query-types | --query-segments ]  [ --segment-url | --database | --dmt-files ] options ",
        description = "Performs a number of queries against either a set of DMT files or a segment database"
    )


    # Major modes
    parser.add_option("-p", "--ping",           action = "store_true", help = "Ping the target server")
    parser.add_option("-y", "--show-types",     action = "store_true", help = "Returns a xml table containing segment type information: ifos, name, version, segment_definer.comment, segment_summary.start_time, segment_summary.end_time, segment_summary.comment")
    parser.add_option("-u", "--query-types",    action = "store_true", help = "Returns a ligolw document whose segment_definer table includes all segment types defined in the given period and included by include-segments and whose segment_summary table indicates the times for which those segments are defined.")
    parser.add_option("-q", "--query-segments", action = "store_true", help = "Returns a ligolw document whose segment table contains the times included by the include-segments flag and excluded by exclude-segments")

    # Time options
    parser.add_option("-s", "--gps-start-time", metavar = "gps_start_time", help = "Start of GPS time range")
    parser.add_option("-e", "--gps-end-time",   metavar = "gps_end_time", help = "End of GPS time range")


    # Data location options
    parser.add_option("-t", "--segment-url",    metavar = "segment_url", help = "Segment URL. Users have to specify either 'https://' for a secure connection or 'http://' for an insecure connection in the segment database url. For example, '--segment-url=https://segdb.ligo.caltech.edu'. No need to specify port number. ")
    parser.add_option("-d", "--database",   metavar = "use_database", action = "store_true", help = "use database specified by environment variable S6_SEGMENT_SERVER. For example, 'S6_SEGMENT_SERVER=https://segdb.ligo.caltech.edu'")
    parser.add_option("-f", "--dmt-files",   metavar = "use_files", action = "store_true", help = "use files in directory specified by environment variable ONLINEDQ, for example, 'ONLINEDQ=file:///path_to_dmt'. 'file://' is the prefix, the acutal directory to DMT xml files starts with '/'.")


    # Other options.  The help texts refer to the options by the names they
    # are actually registered under (--include-segments/--exclude-segments).
    parser.add_option("-a", "--include-segments", metavar = "include_segments", help = "This option expects a comma separated list of a colon separated sublist of interferometer, segment type, and version. The union of segments from all types and versions specified is returned. Use --show-types to see what types are available.   For example: --include-segments H1:DMT-SCIENCE:1,H1:DMT-INJECTION:2 will return the segments for which H1 is in either SCIENCE version 1 or INJECTION version 2 mode. If version information is not provided, the union of the segments of the latest version of requested segment type(s) will be returned.")

    parser.add_option("-b", "--exclude-segments", metavar = "exclude_segments", help = "This option has to be used in conjunction with --include-segments. --exclude-segments subtracts the union of unwanted segments from the specified types from the results of --include-segments. If version information is not provided, --exclude-segments subtracts the union of segments from the latest version of the specified segment types. For example, --include-segments H1:DMT-SCIENCE:1,H1:DMT-INJECTION:2 --exclude-segments H1:DMT-WIND:1,H1:DMT-NOT_LOCKED:2,H2:DMT-NOT_LOCKED:2 will subtract the union of segments which H1 is in version 1 WIND and H1,H2 is version 2 NOT_LOCKED from the result of --include-segments H1:DMT-SCIENCE:1,H1:DMT-INJECTION:2")


    parser.add_option("-S", "--strict-off", metavar = "use_strict", action = "store_true", help = "The default behavior is to truncate segments so that returned segments are entirely in the interval [gps-start-time, gps-end-time).  However if this option is given, the entire non-truncated segment is returned if any part of it overlaps the interval.")

    parser.add_option("-n", "--result-name", metavar = "result_name", default = "RESULT", help = "Name for result segment definer (default = RESULT)")

    parser.add_option("-o", "--output-file",   metavar = "output_file", help = "File to which output should be written.  Defaults to stdout.")

    options, others = parser.parse_args()

    # Make sure we have exactly one thing to do
    modes = [options.ping, options.query_types, options.show_types, options.query_segments]
    if sum(1 for mode in modes if mode) != 1:
        raise ValueError("Exactly one of [ --ping | --show-types | --query-types | --query-segments ] must be provided")


    # Make sure we know who to contact for data
    database_location = None
    file_location     = None

    if options.segment_url:
        url = options.segment_url
        if url.startswith(('ldbd', 'http')):
            database_location = url
        elif url.startswith('file:'):
            # Accept both "file:///dir" and "file:/dir" without eating the
            # first character of the path in the latter case.
            prefix = 'file://' if url.startswith('file://') else 'file:'
            file_location = url[len(prefix):]
        else:
            # Any other scheme: download the document and treat the
            # download directory as a directory of DMT files.
            file_location = _download_to_temp_dir(url)
    elif options.database:
        if 'S6_SEGMENT_SERVER' not in os.environ:
            raise ValueError( "--database specified but S6_SEGMENT_SERVER not set" )
        database_location = os.environ['S6_SEGMENT_SERVER']
    elif options.dmt_files:
        if 'ONLINEDQ' not in os.environ:
            raise ValueError( "--dmt-files specified but ONLINEDQ not set" )
        tmp = os.environ['ONLINEDQ']
        if tmp.startswith('file://'):
            tmp = tmp[len('file://'):]
        file_location = tmp
    else:
        raise ValueError( "One of [ --segment-url | --database | --dmt-files ] must be provided" )


    # Unless we're pinging, make sure we have start and end times
    if options.ping:
        if not database_location:
            raise ValueError("--ping requires [ --segment-url https://... | --database ]")
    else:
        if not options.gps_start_time:
            raise ValueError( "missing required argument --gps-start-time" )

        if not options.gps_end_time:
            raise ValueError( "missing required argument --gps-end-time" )

        if not options.show_types and not options.include_segments:
            raise ValueError( "missing required argument --include-segments")

    return options, database_location, file_location




#
# =============================================================================
#
#                                 General utilities
#
# =============================================================================
#


def seg_spec_to_sql(spec):
    """Given a string of the form ifo:name:version, ifo:name:* or ifo:name
    constructs a parenthesised SQL clause to restrict a search to that
    segment definer.

    The ifo is always constrained.  The name is constrained unless it is
    missing or '*'; the version is only considered when the name is
    constrained, and is skipped when missing or '*'.  Values are
    interpolated into the clause as given, without quoting or escaping."""

    fields = spec.split(':')
    conditions = ["segment_definer.ifos = '%s'" % fields[0]]

    name_given = len(fields) > 1 and fields[1] != '*'
    if name_given:
        conditions.append("segment_definer.name = '%s'" % fields[1])

    if name_given and len(fields) > 2 and fields[2] != '*':
        conditions.append("segment_definer.version = %s" % fields[2])

    return "(" + " AND ".join(conditions) + ")"



#
# The results of show-types is a join against segment_definer and segment
# summary, and so does not fit into an existing table type.  So here we
# define a new type so that the ligolw routines can generate the XML
#
class ShowTypesResultTable(table.Table):
    """Table holding the rows produced by --show-types."""

    tableName = "show_types_result"

    # Column name -> column type string.  The columns correspond, in name,
    # to the values selected by the show-types query: the segment definer's
    # ifos/name/version/comment plus the start time, end time and comment
    # of a segment summary.  ShowTypesResult derives its __slots__ from the
    # keys of this mapping.
    validcolumns = {
        "ifos": "lstring",
        "name": "lstring",
        "version": "int_4s",
        "segment_definer_comment": "lstring",
        "segment_summary_start_time": "int_4s",
        "segment_summary_end_time": "int_4s",
        "segment_summary_comment": "lstring"
        }
    


class ShowTypesResult(object):
    """Row object for ShowTypesResultTable: one attribute per column."""

    # The slots are exactly the column names of the table, so instances can
    # carry those attributes and no others.
    __slots__ = tuple(ShowTypesResultTable.validcolumns.keys())

    def get_pyvalue(self):
        # NOTE(review): this method cannot work as written.  Neither "value"
        # nor "type" is among the slots above, so the attribute access
        # raises AttributeError, and "ligolwtypes" is not imported in the
        # visible part of this file.  It looks like a leftover copied from
        # another row class -- confirm whether it can be removed.
        if self.value is None:
            return None
        return ligolwtypes.ToPyType[self.type or "lstring"](self.value)


# Register the row class on the table class.
ShowTypesResultTable.RowType = ShowTypesResult



#
# =============================================================================
#
#                          Methods that implement major modes
#
# =============================================================================
#
def run_show_types(doc, connection, engine, gps_start_time, gps_end_time, included_segments_string, excluded_segments_string):
    """
    Implements --show-types: appends a ShowTypesResultTable to the first
    child of doc and fills it with one row per coalesced segment-summary
    interval of every segment type whose summary overlaps
    [gps_start_time, gps_end_time].

    NULL comments are reported as '-'.  The connection and the two
    segment-string arguments are accepted but not used.  The engine is
    closed before returning.
    """
    type_table = lsctables.New(ShowTypesResultTable)
    doc.childNodes[0].appendChild(type_table)

    query = """SELECT segment_definer.ifos, segment_definer.name, segment_definer.version,
                 (CASE WHEN segment_definer.comment IS NULL THEN '-' WHEN segment_definer.comment IS NOT NULL THEN segment_definer.comment END),
                 segment_summary.start_time, segment_summary.end_time,
                 (CASE WHEN segment_summary.comment IS NULL THEN '-' WHEN segment_summary.comment IS NOT NULL THEN segment_summary.comment END)
          FROM  segment_definer, segment_summary
          WHERE segment_definer.segment_def_id = segment_summary.segment_def_id
          AND   NOT (segment_summary.start_time > %d OR %d > segment_summary.end_time)
          """ % (gps_end_time, gps_start_time)

    # Group the summary intervals by everything except their times, so that
    # intervals of the same type (and comments) can be merged.
    intervals_by_type = {}

    for record in engine.query(query):
        ifos, name, version, definer_comment, start_time, end_time, summary_comment = record
        type_key = (ifos, name, version, definer_comment, summary_comment)
        intervals_by_type.setdefault(type_key, []).append(ligo.segments.segment(start_time, end_time))

    for type_key in intervals_by_type:
        merged = ligo.segments.segmentlist(intervals_by_type[type_key])
        merged.coalesce()

        for interval in merged:
            entry = ShowTypesResult()
            entry.ifos                    = type_key[0].strip()
            entry.name                    = type_key[1]
            entry.version                 = type_key[2]
            entry.segment_definer_comment = type_key[3]
            entry.segment_summary_comment = type_key[4]
            entry.segment_summary_start_time, entry.segment_summary_end_time = interval

            type_table.append(entry)

    engine.close()


    




def run_query_types(doc, proc_id, connection, engine, gps_start_time, gps_end_time, included_segments):
    """
    Implements --query-types: appends a segment_definer and a
    segment_summary table to the first child of doc, describing every
    segment type that matches included_segments (a comma separated list of
    ifo:name:version specs) and has a summary overlapping
    [gps_start_time, gps_end_time].

    Summary intervals are coalesced per type and clipped to the query
    window.  NULL comments are reported as '-'.  The connection argument
    is accepted but not used.  The engine is closed once the query results
    have been read.
    """
    query_window = ligo.segments.segmentlist([ligo.segments.segment(gps_start_time, gps_end_time)])

    sql = """SELECT segment_definer.ifos, segment_definer.name,segment_definer.version,
           (CASE WHEN segment_definer.comment IS NULL THEN '-' WHEN segment_definer.comment IS NOT NULL THEN segment_definer.comment END),
           segment_summary.start_time, segment_summary.end_time,
           (CASE WHEN segment_summary.comment IS NULL THEN '-' WHEN segment_summary.comment IS NOT NULL THEN segment_summary.comment END)
    FROM segment_definer, segment_summary
    WHERE segment_definer.segment_def_id = segment_summary.segment_def_id
    AND NOT(%d > segment_summary.end_time OR segment_summary.start_time > %d)
    """ % (gps_start_time, gps_end_time)

    # One parenthesised clause per requested type, any of which may match
    restrictions = [seg_spec_to_sql(spec) for spec in included_segments.split(',')]

    if restrictions:
        sql += " AND (" + "OR ".join(restrictions) + ")"

    # (ifos, name, version, definer comment, summary comment) -> segmentlist
    summaries_by_type = {}

    for record in engine.query(sql):
        sd_ifo, sd_name, sd_vers, sd_comment, ss_start, ss_end, ss_comment = record
        type_key = (sd_ifo, sd_name, sd_vers, sd_comment, ss_comment)
        summaries_by_type.setdefault(type_key, ligo.segments.segmentlist([]))
        summaries_by_type[type_key] |= ligo.segments.segmentlist([ligo.segments.segment(ss_start, ss_end)])

    engine.close()

    # Create segment definer and segment_summary tables
    definer_table = lsctables.New(lsctables.SegmentDefTable, columns = ["process_id", "segment_def_id", "ifos", "name", "version", "comment"])
    doc.childNodes[0].appendChild(definer_table)

    summary_table = lsctables.New(lsctables.SegmentSumTable, columns = ["process_id", "segment_sum_id", "start_time", "start_time_ns", "end_time", "end_time_ns", "comment", "segment_def_id"])
    doc.childNodes[0].appendChild(summary_table)

    for type_key in summaries_by_type:
        sd_ifo, sd_name, sd_vers, sd_comment, ss_comment = type_key

        # Coalesce, then make sure the intervals fall within the query window
        summaries_by_type[type_key].coalesce()
        summaries_by_type[type_key] &= query_window

        definer_id = definer_table.get_next_id()

        definer_row                = lsctables.SegmentDef()
        definer_row.process_id     = proc_id
        definer_row.segment_def_id = definer_id
        definer_row.ifos           = sd_ifo
        definer_row.name           = sd_name
        definer_row.version        = sd_vers
        definer_row.comment        = sd_comment
        definer_table.append(definer_row)

        # One segment_summary row per remaining interval of this type
        for interval in summaries_by_type[type_key]:
            summary_row                = lsctables.SegmentSum()
            summary_row.comment        = ss_comment
            summary_row.process_id     = proc_id
            summary_row.segment_def_id = definer_id
            summary_row.segment_sum_id = summary_table.get_next_id()
            summary_row.start_time     = interval[0]
            summary_row.start_time_ns  = 0
            summary_row.end_time       = interval[1]
            summary_row.end_time_ns    = 0
            summary_table.append(summary_row)



def _parse_segment_spec(spec_string, kind):
    """
    Split one ifo:name[:version] spec into (ifo, name, version).

    version is an int >= 1, or '*' when it is missing or given as '*'.
    kind ("Included" or "Excluded") only labels the error message.  A
    malformed spec prints a message to stderr and exits with status 1; a
    non-numeric version raises ValueError.
    """
    spec = spec_string.split(':')

    if len(spec) < 2 or len(spec) > 3:
        print("%s segements must be of the form ifo:name:version or ifo:name:*" % kind, file=sys.stderr)
        sys.exit(1)

    ifo, name = spec[0], spec[1]

    # Compare by value: identity tests against literals only work by
    # accident of interning.
    if len(spec) == 3 and spec[2] != '*':
        version = int(spec[2])
        if version < 1:
            print("Segment version numbers must be greater than zero", file=sys.stderr)
            sys.exit(1)
    else:
        version = '*'

    return ifo, name, version


def _expand_segment_specs(engine, specs, kind, gps_start_time, gps_end_time):
    """
    Turn a comma separated list of specs into segment definer tuples via
    segmentdb_utils.expand_version_number.

    Returns (segdefs, ifo of the last spec in the list).
    """
    segdefs  = []
    last_ifo = None

    for spec_string in specs.split(','):
        last_ifo, name, version = _parse_segment_spec(spec_string, kind)
        segdefs += segmentdb_utils.expand_version_number(engine, (last_ifo, name, version, gps_start_time, gps_end_time, 0, 0))

    return segdefs, last_ifo


def _query_segment_union(engine, segdefs):
    """Return the coalesced union of the segments of all given definers."""
    # reduce is a builtin on Python 2 only; functools has it on both.
    from functools import reduce

    per_definer = segmentdb_utils.query_segments(engine, 'segment', segdefs)

    # The empty initial value keeps this well defined when nothing matched.
    return reduce(operator.or_, per_definer, ligo.segments.segmentlist([])).coalesce()


def run_query_segments(doc, process_id, engine, gps_start_time, gps_end_time, include_segments, exclude_segments, result_name):
    """
    Implements --query-segments: computes the union of the segments named
    by include_segments minus the union of those named by exclude_segments
    (if any), and stores it in doc under a new segment definer called
    result_name, version 1.

    include_segments and exclude_segments are comma separated lists of
    ifo:name:version / ifo:name:* / ifo:name specs; both are validated the
    same way.  The segment definers and summaries of the included types are
    also written to doc.

    The result definer takes the ifo of the last *included* spec.
    Previously the ifo of the last excluded spec was used whenever
    exclusions were given, because both loops shared one variable.
    """
    segdefs, result_ifo = _expand_segment_specs(engine, include_segments, "Included", gps_start_time, gps_end_time)

    found_segments = _query_segment_union(engine, segdefs)

    # Record which types were used and when they were defined
    segment_summaries = segmentdb_utils.query_segments(engine, 'segment_summary', segdefs)
    segmentdb_utils.add_segment_info(doc, process_id, segdefs, None, segment_summaries)

    # Subtract the excluded segments, if any were requested
    if exclude_segments:
        ex_segdefs, _ = _expand_segment_specs(engine, exclude_segments, "Excluded", gps_start_time, gps_end_time)
        found_segments -= _query_segment_union(engine, ex_segdefs)

    # Add the result type to the segment definer table
    seg_def_id = segmentdb_utils.add_to_segment_definer(doc, process_id, result_ifo, result_name, 1)

    # and segment summary
    segmentdb_utils.add_to_segment_summary(doc, process_id, seg_def_id, [[gps_start_time, gps_end_time]])

    # and store the segments
    segmentdb_utils.add_to_segment(doc, process_id, seg_def_id, found_segments)
           


#
# =============================================================================
#
#                                 XML/File routines
#
# =============================================================================
#
class ContentHandler(ligolw.LIGOLWContentHandler):
    # Content handler class handed to ligolw_sqlite.insert_from_urls() by
    # setup_files().  setup_files() stores the SQLite connection in this
    # class attribute before calling dbtables.use_in() on the class.
    connection = None

def setup_files(dir_name, gps_start_time, gps_end_time):
    """
    Load the XML files under dir_name that fall in the given GPS range
    into a temporary SQLite database.

    Returns (path of the temporary database file, connection to it).  The
    temporary file is not removed here; the main script deletes it when
    it is done.
    """
    # Only the files inside our time range are of interest
    files_in_range = segmentdb_utils.get_all_files_in_range(dir_name, gps_start_time, gps_end_time)

    # mkstemp hands back an open descriptor; only the path is needed
    descriptor, db_path = tempfile.mkstemp(suffix='.sqlite')
    os.close(descriptor)

    working_name  = dbtables.get_connection_filename(db_path, None, True, False)
    db_connection = ligolw_sqlite.setup(working_name)

    # The handler class carries the connection (calling convention since
    # the ER6 glue change; earlier versions took the connection directly).
    ContentHandler.connection = db_connection
    dbtables.use_in(ContentHandler)
    ligolw_sqlite.insert_from_urls(files_in_range, ContentHandler)

    segmentdb_utils.ensure_segment_table(db_connection)

    return db_path, db_connection
    



#
# =============================================================================
#
#                                     Main
#
# =============================================================================
#

if __name__ == '__main__':
    options, database_location, file_location  = parse_command_line()

    # Ping the database and exit if requested
    if options.ping:
        connection = segmentdb_utils.setup_database(database_location)
        print(connection.ping())
        sys.exit( 0 )

    gps_start_time = int(options.gps_start_time)
    gps_end_time   = int(options.gps_end_time)

    # set up the response
    doc = ligolw.Document()
    doc.appendChild(ligolw.LIGO_LW())
    process_id = process.register_to_xmldoc(doc, PROGRAM_NAME, options.__dict__, version = git_version.id, cvs_entry_time = __date__).process_id

    # Path of the temporary SQLite database, when one had to be built from
    # DMT files; None when talking to a segment database.
    temp_db = None

    if database_location:
        connection = segmentdb_utils.setup_database(database_location)
        engine     = query_engine.LdbdQueryEngine(connection)
    else:
        temp_db, connection = setup_files(file_location, gps_start_time, gps_end_time)
        engine     = query_engine.SqliteQueryEngine(connection)

    try:
        # parse_command_line() guarantees that exactly one mode is set
        if options.show_types:
            run_show_types(doc, connection, engine, gps_start_time, gps_end_time,
                           options.include_segments,options.exclude_segments)
        elif options.query_segments:
            run_query_segments(doc, process_id, engine, gps_start_time, gps_end_time,
                               options.include_segments, options.exclude_segments,
                               options.result_name)
        elif options.query_types:
            run_query_types(doc, process_id, connection, engine, gps_start_time, gps_end_time, options.include_segments)

        gz = (options.output_file is not None and
              options.output_file.endswith(".gz"))
        utils.write_filename(doc, options.output_file, gz=gz)
    finally:
        # Clean up the temporary database even when a query or the write
        # failed, so aborted runs do not leave files behind.
        if temp_db is not None:
            os.remove(temp_db)


