/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under both the MIT license found in the
 * LICENSE-MIT file in the root directory of this source tree and the Apache
 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
 * of this source tree.
 */

use std::collections::HashSet;
use std::sync::OnceLock;

/// Returns the set of `(name, arity)` pairs auto-imported from the `erlang`
/// module.
///
/// The Erlang compiler's `sys_core_fold.erl` decides auto-imports in
/// `is_auto_imported/3`, which delegates to `erl_internal:bif/2`; the entries
/// below are derived from that function. The set is built on first use and
/// then cached in a process-wide static.
pub fn erlang_funs() -> &'static HashSet<(&'static str, usize)> {
    const BIFS: &[(&str, usize)] = &[
        ("abs", 1),
        ("alias", 0),
        ("alias", 1),
        ("apply", 2),
        ("apply", 3),
        ("atom_to_binary", 1),
        ("atom_to_binary", 2),
        ("atom_to_list", 1),
        ("binary_part", 2),
        ("binary_part", 3),
        ("binary_to_atom", 1),
        ("binary_to_atom", 2),
        ("binary_to_existing_atom", 1),
        ("binary_to_existing_atom", 2),
        ("binary_to_integer", 1),
        ("binary_to_integer", 2),
        ("binary_to_float", 1),
        ("binary_to_list", 1),
        ("binary_to_list", 3),
        ("binary_to_term", 1),
        ("binary_to_term", 2),
        ("bitsize", 1),
        ("bit_size", 1),
        ("bitstring_to_list", 1),
        ("byte_size", 1),
        ("ceil", 1),
        ("check_old_code", 1),
        ("check_process_code", 2),
        ("check_process_code", 3),
        ("date", 0),
        ("delete_module", 1),
        ("demonitor", 1),
        ("demonitor", 2),
        ("disconnect_node", 1),
        ("element", 2),
        ("erase", 0),
        ("erase", 1),
        ("error", 1),
        ("error", 2),
        ("error", 3),
        ("exit", 1),
        ("exit", 2),
        ("float", 1),
        ("float_to_list", 1),
        ("float_to_list", 2),
        ("float_to_binary", 1),
        ("float_to_binary", 2),
        ("floor", 1),
        ("garbage_collect", 0),
        ("garbage_collect", 1),
        ("garbage_collect", 2),
        ("get", 0),
        ("get", 1),
        ("get_keys", 0),
        ("get_keys", 1),
        ("group_leader", 0),
        ("group_leader", 2),
        ("halt", 0),
        ("halt", 1),
        ("halt", 2),
        ("hd", 1),
        ("integer_to_binary", 1),
        ("integer_to_binary", 2),
        ("integer_to_list", 1),
        ("integer_to_list", 2),
        ("iolist_size", 1),
        ("iolist_to_binary", 1),
        ("is_alive", 0),
        ("is_process_alive", 1),
        ("is_atom", 1),
        ("is_boolean", 1),
        ("is_binary", 1),
        ("is_bitstring", 1),
        ("is_float", 1),
        ("is_function", 1),
        ("is_function", 2),
        ("is_integer", 1),
        ("is_list", 1),
        ("is_map", 1),
        ("is_map_key", 2),
        ("is_number", 1),
        ("is_pid", 1),
        ("is_port", 1),
        ("is_reference", 1),
        ("is_tuple", 1),
        ("is_record", 2),
        ("is_record", 3),
        ("length", 1),
        ("link", 1),
        ("list_to_atom", 1),
        ("list_to_binary", 1),
        ("list_to_bitstring", 1),
        ("list_to_existing_atom", 1),
        ("list_to_float", 1),
        ("list_to_integer", 1),
        ("list_to_integer", 2),
        ("list_to_pid", 1),
        ("list_to_port", 1),
        ("list_to_ref", 1),
        ("list_to_tuple", 1),
        ("load_module", 2),
        ("make_ref", 0),
        ("map_size", 1),
        ("map_get", 2),
        ("max", 2),
        ("min", 2),
        ("module_loaded", 1),
        ("monitor", 2),
        ("monitor", 3),
        ("monitor_node", 2),
        ("node", 0),
        ("node", 1),
        ("nodes", 0),
        ("nodes", 1),
        ("now", 0),
        ("open_port", 2),
        ("pid_to_list", 1),
        ("port_to_list", 1),
        ("port_close", 1),
        ("port_command", 2),
        ("port_command", 3),
        ("port_connect", 2),
        ("port_control", 3),
        ("pre_loaded", 0),
        ("process_flag", 2),
        ("process_flag", 3),
        ("process_info", 1),
        ("process_info", 2),
        ("processes", 0),
        ("purge_module", 1),
        ("put", 2),
        ("ref_to_list", 1),
        ("register", 2),
        ("registered", 0),
        ("round", 1),
        ("self", 0),
        ("setelement", 3),
        ("size", 1),
        ("spawn", 1),
        ("spawn", 2),
        ("spawn", 3),
        ("spawn", 4),
        ("spawn_link", 1),
        ("spawn_link", 2),
        ("spawn_link", 3),
        ("spawn_link", 4),
        ("spawn_request", 1),
        ("spawn_request", 2),
        ("spawn_request", 3),
        ("spawn_request", 4),
        ("spawn_request", 5),
        ("spawn_request_abandon", 1),
        ("spawn_monitor", 1),
        ("spawn_monitor", 2),
        ("spawn_monitor", 3),
        ("spawn_monitor", 4),
        ("spawn_opt", 2),
        ("spawn_opt", 3),
        ("spawn_opt", 4),
        ("spawn_opt", 5),
        ("split_binary", 2),
        ("statistics", 1),
        ("term_to_binary", 1),
        ("term_to_binary", 2),
        ("term_to_iovec", 1),
        ("term_to_iovec", 2),
        ("throw", 1),
        ("time", 0),
        ("tl", 1),
        ("trunc", 1),
        ("tuple_size", 1),
        ("tuple_to_list", 1),
        ("unalias", 1),
        ("unlink", 1),
        ("unregister", 1),
        ("whereis", 1),
    ];

    // Built lazily from the static table above on the first call.
    static CACHE: OnceLock<HashSet<(&str, usize)>> = OnceLock::new();
    CACHE.get_or_init(|| BIFS.iter().copied().collect())
}

/// Returns `true` if `f/a` is auto-imported from the `erlang` module.
///
/// A name/arity pair qualifies if it is either an auto-imported BIF (see
/// [`is_erlang_fun`]) or a built-in type (see [`is_erlang_type`]). Both
/// namespaces are checked, so callers get `true` for a pair that is only a
/// function as well as for one that is only a type.
pub fn in_erlang_module(f: &str, a: usize) -> bool {
    is_erlang_fun(f, a) || is_erlang_type(f, a)
}

/// Returns `true` if `f/a` is one of the BIFs in [`erlang_funs`], i.e. a
/// function that can be called without the `erlang:` prefix.
pub fn is_erlang_fun(f: &str, a: usize) -> bool {
    let key = (f, a);
    erlang_funs().contains(&key)
}

/// Returns `true` if `f/a` names a built-in or predefined Erlang type.
///
/// The entries follow the "Types and Function Specifications" chapter of the
/// Erlang reference manual
/// (<https://www.erlang.org/doc/reference_manual/typespec.html>). That chapter
/// covers both the type grammar and table 7.1 of predefined aliases. Arity is
/// significant: `list/1` is a type but `list/2` is not.
pub fn is_erlang_type(f: &str, a: usize) -> bool {
    match a {
        0 => matches!(
            f,
            // Built-in types from the type grammar.
            "any"
                | "atom"
                | "float"
                | "fun"
                | "integer"
                | "pos_integer"
                | "neg_integer"
                | "non_neg_integer"
                | "none"
                | "pid"
                | "port"
                | "reference"
                | "tuple"
                | "nonempty_maybe_improper_list"
                // Predefined aliases (table 7.1).
                | "term"
                | "binary"
                | "nonempty_binary"
                | "bitstring"
                | "nonempty_bitstring"
                | "boolean"
                | "byte"
                | "char"
                | "nil"
                | "number"
                | "list"
                | "maybe_improper_list"
                | "nonempty_list"
                | "string"
                | "nonempty_string"
                | "iodata"
                | "iolist"
                | "map"
                | "function"
                | "module"
                | "mfa"
                | "arity"
                | "identifier"
                | "node"
                | "timeout"
                | "no_return"
                // Kept from the original list; not part of the sections above.
                | "dynamic"
        ),
        // `fun((...) -> T)` and the parameterised list types.
        1 => matches!(f, "fun" | "list" | "nonempty_list"),
        2 => matches!(
            f,
            "maybe_improper_list" | "nonempty_improper_list" | "nonempty_maybe_improper_list"
        ),
        _ => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_erlang_fun() {
        assert!(is_erlang_fun("abs", 1));
        assert!(!is_erlang_fun("foo", 1));
        // Arity is part of the lookup key.
        assert!(!is_erlang_fun("abs", 2));
    }

    #[test]
    fn test_is_erlang_type() {
        assert!(is_erlang_type("any", 0));
        assert!(is_erlang_type("list", 1));
        assert!(is_erlang_type("maybe_improper_list", 2));
        assert!(!is_erlang_type("list", 3));
        assert!(!is_erlang_type("foo", 0));
    }

    #[test]
    fn test_in_erlang_module() {
        // A BIF with no type of the same name/arity.
        assert!(in_erlang_module("abs", 1));
        // A type with no BIF of the same name/arity.
        assert!(in_erlang_module("term", 0));
        assert!(!in_erlang_module("foo", 1));
    }
}
