/*
This tool is part of the WhiteboxTools geospatial analysis library.
Authors: Dr. John Lindsay
Created: 27/07/2017
Last Modified: 11/02/2019
License: MIT
*/

use crate::raster::*;
use crate::structures::Array2D;
use crate::tools::*;
use num_cpus;
use std::env;
use std::f64;
use std::io::{Error, ErrorKind};
use std::path;
use std::sync::mpsc;
use std::sync::Arc;
use std::thread;

/// Panchromatic sharpening, or simply pan-sharpening, refers to a range of techniques that can be used to merge
/// finer spatial resolution panchromatic images with coarser spatial resolution multi-spectral images. The
/// multi-spectral data provide colour information, while the panchromatic image provides improved spatial detail.
/// This procedure is sometimes called image fusion. Jensen (2015) describes panchromatic sharpening in detail.
///
/// Whitebox provides two common methods for panchromatic sharpening: the Brovey transformation and the
/// Intensity-Hue-Saturation (IHS) method. Both techniques give the best results when the range of wavelengths
/// detected by the panchromatic image overlaps significantly with the wavelength range covered by the three
/// multi-spectral bands that are used. When this is not the case, the resulting colour composite will likely have
/// colour properties that differ from those of the colour composite generated from the original multispectral
/// images. For Landsat ETM+ data, the panchromatic band is sensitive to EMR in the range of 0.52-0.90 micrometres,
/// which corresponds closely to the green (band 2), red (band 3), and near-infrared (band 4) bands.
///
/// The multispectral data may be supplied either as a single colour-composite image (`--composite`) or as separate
/// red, green, and blue band images (`--red`, `--green`, `--blue`). The output is an RGBA32 colour composite on the
/// grid of the panchromatic image.
///
/// # Reference
/// Jensen, J. R. (2015). Introductory Digital Image Processing: A Remote Sensing Perspective.
///
/// # See Also
/// `CreateColourComposite`
pub struct PanchromaticSharpening {
    // Tool identifier, as used with `-r=` in the example usage.
    name: String,
    // Toolbox path under which the tool is listed.
    toolbox: String,
    // One-line summary of what the tool does.
    description: String,
    // Command-line parameters accepted by the tool.
    parameters: Vec<ToolParameter>,
    // Example command lines shown to users.
    example_usage: String,
}

impl PanchromaticSharpening {
    /// Creates a new `PanchromaticSharpening` tool, registering its command-line parameters
    /// and building the example-usage string shown to users.
    ///
    /// # Panics
    /// Panics if the current working directory or the path of the running executable cannot
    /// be determined; both are needed to build the example-usage string.
    pub fn new() -> PanchromaticSharpening {
        let name = "PanchromaticSharpening".to_string();
        let toolbox = "Image Processing Tools/Image Enhancement".to_string();
        let description = "Increases the spatial resolution of image data by combining multispectral bands with panchromatic data.".to_string();

        // The descriptions mirror `run`: the individual bands are required unless a
        // composite is given, and a composite, when given, takes precedence over the bands.
        let parameters = vec![
            ToolParameter {
                name: "Input Red Band File (optional; only if colour-composite not specified)"
                    .to_owned(),
                flags: vec!["--red".to_owned()],
                description:
                    "Input red band image file. Required if a colour-composite image is not specified."
                        .to_owned(),
                parameter_type: ParameterType::ExistingFile(ParameterFileType::Raster),
                default_value: None,
                optional: true,
            },
            ToolParameter {
                name: "Input Green Band File (optional; only if colour-composite not specified)"
                    .to_owned(),
                flags: vec!["--green".to_owned()],
                description:
                    "Input green band image file. Required if a colour-composite image is not specified."
                        .to_owned(),
                parameter_type: ParameterType::ExistingFile(ParameterFileType::Raster),
                default_value: None,
                optional: true,
            },
            ToolParameter {
                name: "Input Blue Band File (optional; only if colour-composite not specified)"
                    .to_owned(),
                flags: vec!["--blue".to_owned()],
                description:
                    "Input blue band image file. Required if a colour-composite image is not specified."
                        .to_owned(),
                parameter_type: ParameterType::ExistingFile(ParameterFileType::Raster),
                default_value: None,
                optional: true,
            },
            ToolParameter {
                name: "Input Colour-Composite Image File (optional; only if individual bands not specified)".to_owned(),
                flags: vec!["--composite".to_owned()],
                description: "Input colour-composite image file. When specified, it is used in place of the individual red, green, and blue band files.".to_owned(),
                parameter_type: ParameterType::ExistingFile(ParameterFileType::Raster),
                default_value: None,
                optional: true,
            },
            ToolParameter {
                name: "Input Panchromatic Band File".to_owned(),
                flags: vec!["--pan".to_owned()],
                description: "Input panchromatic band file.".to_owned(),
                parameter_type: ParameterType::ExistingFile(ParameterFileType::Raster),
                default_value: None,
                optional: false,
            },
            ToolParameter {
                name: "Output Colour Composite File".to_owned(),
                flags: vec!["-o".to_owned(), "--output".to_owned()],
                description: "Output colour composite file.".to_owned(),
                parameter_type: ParameterType::NewFile(ParameterFileType::Raster),
                default_value: None,
                optional: false,
            },
            ToolParameter {
                name: "Pan-Sharpening Method".to_owned(),
                flags: vec!["--method".to_owned()],
                description: "Options include 'brovey' (default) and 'ihs'".to_owned(),
                parameter_type: ParameterType::OptionList(vec![
                    "brovey".to_owned(),
                    "ihs".to_owned(),
                ]),
                default_value: Some("brovey".to_owned()),
                optional: true,
            },
        ];

        // Derive a short executable name (the path relative to the current directory, with
        // separators and dots stripped) for use in the example command lines.
        let sep: String = path::MAIN_SEPARATOR.to_string();
        let p = format!("{}", env::current_dir().unwrap().display());
        let e = format!("{}", env::current_exe().unwrap().display());
        let mut short_exe = e
            .replace(&p, "")
            .replace(".exe", "")
            .replace(".", "")
            .replace(&sep, "");
        if e.contains(".exe") {
            short_exe += ".exe";
        }
        // '*' is a placeholder for the platform path separator.
        let usage = format!(">>.*{0} -r={1} -v --wd=\"*path*to*data*\" --red=red.tif --green=green.tif --blue=blue.tif --pan=pan.tif --output=pan_sharp.tif --method='brovey'
>>.*{0} -r={1} -v --wd=\"*path*to*data*\" --composite=image.tif --pan=pan.tif --output=pan_sharp.tif --method='ihs'", short_exe, name).replace("*", &sep);

        PanchromaticSharpening {
            name,
            description,
            toolbox,
            parameters,
            example_usage: usage,
        }
    }
}

impl WhiteboxTool for PanchromaticSharpening {
    fn get_source_file(&self) -> String {
        String::from(file!())
    }

    fn get_tool_name(&self) -> String {
        self.name.clone()
    }

    fn get_tool_description(&self) -> String {
        self.description.clone()
    }

    fn get_tool_parameters(&self) -> String {
        let mut s = String::from("{\"parameters\": [");
        for i in 0..self.parameters.len() {
            if i < self.parameters.len() - 1 {
                s.push_str(&(self.parameters[i].to_string()));
                s.push_str(",");
            } else {
                s.push_str(&(self.parameters[i].to_string()));
            }
        }
        s.push_str("]}");
        s
    }

    fn get_example_usage(&self) -> String {
        self.example_usage.clone()
    }

    fn get_toolbox(&self) -> String {
        self.toolbox.clone()
    }

    fn run<'a>(
        &self,
        args: Vec<String>,
        working_directory: &'a str,
        verbose: bool,
    ) -> Result<(), Error> {
        let mut red_file = String::new();
        let mut green_file = String::new();
        let mut blue_file = String::new();
        let mut composite_file = String::new();
        let mut use_composite = false;
        let mut pan_file = String::new();
        let mut output_file = String::new();
        let mut fusion_method = String::from("brovey");

        if args.len() == 0 {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "Tool run with no parameters.",
            ));
        }
        for i in 0..args.len() {
            let mut arg = args[i].replace("\"", "");
            arg = arg.replace("\'", "");
            let cmd = arg.split("="); // in case an equals sign was used
            let vec = cmd.collect::<Vec<&str>>();
            let mut keyval = false;
            if vec.len() > 1 {
                keyval = true;
            }
            let flag_val = vec[0].to_lowercase().replace("--", "-");
            if flag_val == "-r" || flag_val == "-red" {
                red_file = if keyval {
                    vec[1].to_string()
                } else {
                    args[i + 1].to_string()
                };
            } else if flag_val == "-g" || flag_val == "-green" {
                green_file = if keyval {
                    vec[1].to_string()
                } else {
                    args[i + 1].to_string()
                };
            } else if flag_val == "-b" || flag_val == "-blue" {
                blue_file = if keyval {
                    vec[1].to_string()
                } else {
                    args[i + 1].to_string()
                };
            } else if flag_val == "-p" || flag_val == "-pan" {
                pan_file = if keyval {
                    vec[1].to_string()
                } else {
                    args[i + 1].to_string()
                };
            } else if flag_val == "-c" || flag_val == "-composite" {
                composite_file = if keyval {
                    vec[1].to_string()
                } else {
                    args[i + 1].to_string()
                };
                use_composite = true;
            } else if flag_val == "-o" || flag_val == "-output" {
                output_file = if keyval {
                    vec[1].to_string()
                } else {
                    args[i + 1].to_string()
                };
            } else if flag_val == "-method" {
                fusion_method = if keyval {
                    vec[1].to_string()
                } else {
                    args[i + 1].to_string()
                };
                fusion_method = if fusion_method.to_lowercase().contains("bro") {
                    String::from("brovey")
                } else {
                    String::from("ihs")
                };
            }
        }

        if verbose {
            println!("***************{}", "*".repeat(self.get_tool_name().len()));
            println!("* Welcome to {} *", self.get_tool_name());
            println!("***************{}", "*".repeat(self.get_tool_name().len()));
        }

        let sep: String = path::MAIN_SEPARATOR.to_string();

        let mut progress: usize;
        let mut old_progress: usize = 1;

        if !red_file.contains(&sep) && !red_file.contains("/") {
            red_file = format!("{}{}", working_directory, red_file);
        }
        if !green_file.contains(&sep) && !green_file.contains("/") {
            green_file = format!("{}{}", working_directory, green_file);
        }
        if !blue_file.contains(&sep) && !blue_file.contains("/") {
            blue_file = format!("{}{}", working_directory, blue_file);
        }
        if !composite_file.contains(&sep) && !composite_file.contains("/") {
            composite_file = format!("{}{}", working_directory, composite_file);
        }
        if !pan_file.contains(&sep) && !pan_file.contains("/") {
            pan_file = format!("{}{}", working_directory, pan_file);
        }
        if !output_file.contains(&sep) && !output_file.contains("/") {
            output_file = format!("{}{}", working_directory, output_file);
        }

        let num_procs = num_cpus::get() as isize;

        let mut input: Array2D<f64>;
        let rows_ms: isize;
        let columns_ms: isize;
        let mut nodata_ms = 0f64;
        let north: f64;
        let west: f64;
        let resolution_x: f64;
        let resolution_y: f64;

        if use_composite {
            if verbose {
                println!("Reading multispec image data...")
            };
            let input_c = Raster::new(&composite_file, "r")?;

            rows_ms = input_c.configs.rows as isize;
            columns_ms = input_c.configs.columns as isize;
            nodata_ms = input_c.configs.nodata;

            north = input_c.configs.north;
            west = input_c.configs.west;
            resolution_x = input_c.configs.resolution_x;
            resolution_y = input_c.configs.resolution_y;

            input = input_c.get_data_as_array2d();
        } else {
            if verbose {
                println!("Reading red band data...")
            };
            let input_r = Raster::new(&red_file, "r")?;
            if verbose {
                println!("Reading green band data...")
            };
            let input_g = Raster::new(&green_file, "r")?;
            if verbose {
                println!("Reading blue band data...")
            };
            let input_b = Raster::new(&blue_file, "r")?;

            // make sure the input files have the same size
            if input_r.configs.rows != input_g.configs.rows
                || input_r.configs.columns != input_g.configs.columns
            {
                return Err(Error::new(ErrorKind::InvalidInput,
                                      "The input files must have the same number of rows and columns and spatial extent."));
            }
            if input_r.configs.rows != input_b.configs.rows
                || input_r.configs.columns != input_b.configs.columns
            {
                return Err(Error::new(ErrorKind::InvalidInput,
                                      "The input files must have the same number of rows and columns and spatial extent."));
            }

            let nodata_r = input_r.configs.nodata;
            let nodata_g = input_g.configs.nodata;
            let nodata_b = input_b.configs.nodata;

            rows_ms = input_r.configs.rows as isize;
            columns_ms = input_r.configs.columns as isize;

            north = input_r.configs.north;
            west = input_r.configs.west;
            resolution_x = input_r.configs.resolution_x;
            resolution_y = input_r.configs.resolution_y;

            input = Array2D::new(rows_ms, columns_ms, nodata_ms, nodata_ms)?; // : Array2D<f64>
            let (mut r, mut g, mut b): (f64, f64, f64);
            let (mut r_out, mut g_out, mut b_out): (u32, u32, u32);
            let r_min = input_r.configs.display_min;
            let r_range = input_r.configs.display_max - input_r.configs.display_min;
            let g_min = input_g.configs.display_min;
            let g_range = input_g.configs.display_max - input_g.configs.display_min;
            let b_min = input_b.configs.display_min;
            let b_range = input_b.configs.display_max - input_b.configs.display_min;
            for row in 0..rows_ms {
                for col in 0..columns_ms {
                    r = input_r[(row, col)];
                    g = input_g[(row, col)];
                    b = input_b[(row, col)];
                    if r != nodata_r && g != nodata_g && b != nodata_b {
                        r = (r - r_min) / r_range * 255f64;
                        if r < 0f64 {
                            r = 0f64;
                        }
                        if r > 255f64 {
                            r = 255f64;
                        }
                        r_out = r as u32;

                        g = (g - g_min) / g_range * 255f64;
                        if g < 0f64 {
                            g = 0f64;
                        }
                        if g > 255f64 {
                            g = 255f64;
                        }
                        g_out = g as u32;

                        b = (b - b_min) / b_range * 255f64;
                        if b < 0f64 {
                            b = 0f64;
                        }
                        if b > 255f64 {
                            b = 255f64;
                        }
                        b_out = b as u32;

                        input[(row, col)] =
                            ((255 << 24) | (b_out << 16) | (g_out << 8) | r_out) as f64;
                    }
                }
            }
        }

        let input = Arc::new(input);

        if verbose {
            println!("Reading pan image data...")
        };
        let pan = Arc::new(Raster::new(&pan_file, "r")?);
        let rows_pan = pan.configs.rows as isize;
        let columns_pan = pan.configs.columns as isize;
        let nodata_pan = pan.configs.nodata;
        let pan_min = pan.configs.display_min;
        let pan_range = pan.configs.display_max - pan.configs.display_min;

        let start = Instant::now();

        let mut output = Raster::initialize_using_file(&output_file, &pan);
        output.configs.photometric_interp = PhotometricInterpretation::RGB;
        output.configs.data_type = DataType::RGBA32;
        let nodata_out = 0f64;
        output.reinitialize_values(nodata_out);

        if fusion_method == String::from("brovey") {
            let (tx, rx) = mpsc::channel();
            for tid in 0..num_procs {
                let pan = pan.clone();
                let input = input.clone();
                let tx = tx.clone();
                thread::spawn(move || {
                    let get_column_from_x =
                        |x: f64| -> isize { ((x - west) / resolution_x).floor() as isize };
                    let get_row_from_y =
                        |y: f64| -> isize { ((north - y) / resolution_y).floor() as isize };
                    let mut p: f64;
                    let mut adj: f64;
                    let (mut r, mut g, mut b): (f64, f64, f64);
                    let (mut r_out, mut g_out, mut b_out): (u32, u32, u32);
                    let (mut x, mut y): (f64, f64);
                    let (mut source_col, mut source_row): (isize, isize);
                    let (mut z_ms, mut z_pan): (f64, f64);
                    for row in (0..rows_pan).filter(|row_val| row_val % num_procs == tid) {
                        y = pan.get_y_from_row(row);
                        source_row = get_row_from_y(y);
                        let mut data = vec![nodata_out; columns_pan as usize];
                        for col in 0..columns_pan {
                            x = pan.get_x_from_column(col);
                            source_col = get_column_from_x(x);
                            z_pan = pan[(row, col)];
                            z_ms = input[(source_row, source_col)];

                            if z_ms != nodata_ms && z_pan != nodata_pan {
                                p = (z_pan - pan_min) / pan_range;
                                if p < 0f64 {
                                    p = 0f64;
                                }
                                if p > 1f64 {
                                    p = 1f64;
                                }

                                r = (z_ms as u32 & 0xFF) as f64;
                                g = ((z_ms as u32 >> 8) & 0xFF) as f64;
                                b = ((z_ms as u32 >> 16) & 0xFF) as f64;

                                adj = (r + g + b) / 3f64;

                                r_out = (r * p / adj * 255f64) as u32;
                                g_out = (g * p / adj * 255f64) as u32;
                                b_out = (b * p / adj * 255f64) as u32;

                                if r_out > 255 {
                                    r_out = 255;
                                }
                                if g_out > 255 {
                                    g_out = 255;
                                }
                                if b_out > 255 {
                                    b_out = 255;
                                }

                                data[col as usize] =
                                    ((255 << 24) | (b_out << 16) | (g_out << 8) | r_out) as f64;
                            }
                        }
                        tx.send((row, data)).unwrap();
                    }
                });
            }

            for row in 0..rows_pan {
                let data = rx.recv().expect("Error receiving data from thread.");
                output.set_row_data(data.0, data.1);
                if verbose {
                    progress = (100.0_f64 * (row + 1) as f64 / rows_pan as f64) as usize;
                    if progress != old_progress {
                        println!("Progress: {}%", progress);
                        old_progress = progress;
                    }
                }
            }
        } else {
            // ihs

            // find the overall maximum in the ms data
            let (tx, rx) = mpsc::channel();
            for tid in 0..num_procs {
                let input = input.clone();
                let tx = tx.clone();
                thread::spawn(move || {
                    let mut overall_max = f64::NEG_INFINITY;
                    let (mut r, mut g, mut b): (f64, f64, f64);
                    let mut z: f64;
                    for row in (0..rows_ms).filter(|row_val| row_val % num_procs == tid) {
                        for col in 0..columns_ms {
                            z = input[(row, col)];
                            if z != nodata_ms {
                                r = (z as u32 & 0xFF) as f64;
                                g = ((z as u32 >> 8) & 0xFF) as f64;
                                b = ((z as u32 >> 16) & 0xFF) as f64;

                                if r > overall_max {
                                    overall_max = r;
                                }
                                if g > overall_max {
                                    overall_max = g;
                                }
                                if b > overall_max {
                                    overall_max = b;
                                }
                            }
                        }
                    }
                    tx.send(overall_max).unwrap();
                });
            }

            let mut overall_max = f64::NEG_INFINITY;
            for tid in 0..num_procs {
                let data = rx.recv().expect("Error receiving data from thread.");
                if data > overall_max {
                    overall_max = data;
                }
                if verbose {
                    progress = (100.0_f64 * tid as f64 / (num_procs - 1) as f64) as usize;
                    if progress != old_progress {
                        println!("Progress: {}%", progress);
                        old_progress = progress;
                    }
                }
            }

            let (tx, rx) = mpsc::channel();
            for tid in 0..num_procs {
                let pan = pan.clone();
                let input = input.clone();
                let tx = tx.clone();
                thread::spawn(move || {
                    let get_column_from_x =
                        |x: f64| -> isize { ((x - west) / resolution_x).floor() as isize };
                    let get_row_from_y =
                        |y: f64| -> isize { ((north - y) / resolution_y).floor() as isize };
                    let mut p: f64;
                    let mut min_rgb: f64;
                    let (mut r, mut g, mut b): (f64, f64, f64);
                    let (mut i, mut h, mut s): (f64, f64, f64);
                    let (mut r_out, mut g_out, mut b_out): (u32, u32, u32);
                    let (mut x, mut y): (f64, f64);
                    let (mut source_col, mut source_row): (isize, isize);
                    let (mut z_ms, mut z_pan): (f64, f64);
                    for row in (0..rows_pan).filter(|row_val| row_val % num_procs == tid) {
                        y = pan.get_y_from_row(row);
                        source_row = get_row_from_y(y);
                        let mut data = vec![nodata_out; columns_pan as usize];
                        for col in 0..columns_pan {
                            x = pan.get_x_from_column(col);
                            source_col = get_column_from_x(x);
                            z_pan = pan[(row, col)];
                            z_ms = input[(source_row, source_col)];
                            if z_ms != nodata_ms && z_pan != nodata_pan {
                                p = (z_pan - pan_min) / pan_range;
                                if p < 0f64 {
                                    p = 0f64;
                                }
                                if p > 1f64 {
                                    p = 1f64;
                                }

                                r = (z_ms as u32 & 0xFF) as f64 / overall_max;
                                g = ((z_ms as u32 >> 8) & 0xFF) as f64 / overall_max;
                                b = ((z_ms as u32 >> 16) & 0xFF) as f64 / overall_max;

                                if r != g || g != b {
                                    // RGB to IHS transformation
                                    i = r + g + b;

                                    min_rgb = r.min(g).min(b);
                                    h = if i == 3f64 {
                                        0f64
                                    } else if b == min_rgb {
                                        (g - b) / (i - 3f64 * b)
                                    } else if r == min_rgb {
                                        (b - r) / (i - 3f64 * r) + 1f64
                                    } else {
                                        //g == min_rgb
                                        (r - g) / (i - 3f64 * g) + 2f64
                                    };

                                    s = if h <= 1f64 {
                                        (i - 3f64 * b) / i
                                    } else if h <= 2f64 {
                                        (i - 3f64 * r) / i
                                    } else {
                                        // h <= 3f64
                                        (i - 3f64 * g) / i
                                    };

                                    // update i for the panchromatic value
                                    i = p * 3f64;

                                    // IHS to RGB transformation
                                    if h <= 1f64 {
                                        r = i * (1f64 + 2f64 * s - 3f64 * s * h) / 3f64;
                                        g = i * (1f64 - s + 3f64 * s * h) / 3f64;
                                        b = i * (1f64 - s) / 3f64;
                                    } else if h <= 2f64 {
                                        r = i * (1f64 - s) / 3f64;
                                        g = i * (1f64 + 2f64 * s - 3f64 * s * (h - 1f64)) / 3f64;
                                        b = i * (1f64 - s + 3f64 * s * (h - 1f64)) / 3f64;
                                    } else {
                                        // h <= 3f64
                                        r = i * (1f64 - s + 3f64 * s * (h - 2f64)) / 3f64;
                                        g = i * (1f64 - s) / 3f64;
                                        b = i * (1f64 + 2f64 * s - 3f64 * s * (h - 2f64)) / 3f64;
                                    }
                                } else {
                                    r *= p;
                                    g *= p;
                                    b *= p;
                                }

                                r_out = (r * 255f64) as u32;
                                g_out = (g * 255f64) as u32;
                                b_out = (b * 255f64) as u32;

                                if r_out > 255 {
                                    r_out = 255;
                                }
                                if g_out > 255 {
                                    g_out = 255;
                                }
                                if b_out > 255 {
                                    b_out = 255;
                                }

                                data[col as usize] =
                                    ((255 << 24) | (b_out << 16) | (g_out << 8) | r_out) as f64;
                            }
                        }
                        tx.send((row, data)).unwrap();
                    }
                });
            }

            for row in 0..rows_pan {
                let data = rx.recv().expect("Error receiving data from thread.");
                output.set_row_data(data.0, data.1);
                if verbose {
                    progress = (100.0_f64 * row as f64 / (rows_pan - 1) as f64) as usize;
                    if progress != old_progress {
                        println!("Progress: {}%", progress);
                        old_progress = progress;
                    }
                }
            }
        }

        let elapsed_time = get_formatted_elapsed_time(start);

        output.add_metadata_entry(format!(
            "Created by whitebox_tools\' {} tool",
            self.get_tool_name()
        ));
        if use_composite {
            output.add_metadata_entry(format!("Input colour composite file: {}", composite_file));
        } else {
            output.add_metadata_entry(format!("Input red-band file: {}", red_file));
            output.add_metadata_entry(format!("Input green-band file: {}", green_file));
            output.add_metadata_entry(format!("Input blue-band file: {}", blue_file));
        }
        output.add_metadata_entry(format!("Input panchromatic file: {}", pan_file));
        output.add_metadata_entry(format!("Pan-sharpening fusion method: {}", fusion_method));
        output.add_metadata_entry(format!("Elapsed Time (excluding I/O): {}", elapsed_time));

        if verbose {
            println!("Saving data...")
        };
        let _ = match output.write() {
            Ok(_) => {
                if verbose {
                    println!("Output file written")
                }
            }
            Err(e) => return Err(e),
        };
        if verbose {
            println!(
                "{}",
                &format!("Elapsed Time (excluding I/O): {}", elapsed_time)
            );
        }

        Ok(())
    }
}
