package functions

import (
	"encoding/json"
	"fmt"
	"log"
	"net"
	"os"
	"strings"
	"time"

	mqtt "github.com/eclipse/paho.mqtt.golang"
	"github.com/gravitl/netclient/config"
	"github.com/gravitl/netclient/daemon"
	"github.com/gravitl/netclient/ncutils"
	"github.com/gravitl/netclient/networking"
	proxyCfg "github.com/gravitl/netclient/nmproxy/config"
	"github.com/gravitl/netclient/nmproxy/turn"
	"github.com/gravitl/netclient/routes"
	"github.com/gravitl/netclient/wireguard"
	"github.com/gravitl/netmaker/logger"
	"github.com/gravitl/netmaker/models"
	"github.com/gravitl/txeh"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

// MQTimeout - timeout applied to mqtt connections
// NOTE(review): the unit is not visible in this file; presumably seconds -- confirm at the call sites
const MQTimeout = 30

// All -- fallback mqtt message handler for all ('#') topics; it only logs
// that a message arrived, together with its topic, and does nothing else
var All mqtt.MessageHandler = func(_ mqtt.Client, msg mqtt.Message) {
	logger.Log(0, "default message handler -- received message but not handling")
	topic := msg.Topic()
	logger.Log(0, "topic: "+topic)
}

// NodeUpdate -- mqtt message handler for /update/<NodeID> topic
//
// The network name is parsed from the topic, the payload is decrypted and
// unmarshalled into a models.Node, and a payload identical to the one last
// cached for that network is skipped. A NODE_DELETE action unsubscribes the
// node and leaves the network. Any other action stores the new node config,
// reconfigures the interface and, when the interface changed (or the server
// forced an update), publishes a DONE signal back to the server.
func NodeUpdate(client mqtt.Client, msg mqtt.Message) {
	network := parseNetworkFromTopic(msg.Topic())
	logger.Log(0, "processing node update for network", network)
	node := config.GetNode(network)
	server := config.Servers[node.Server]
	data, err := decryptMsg(server.Name, msg.Payload())
	if err != nil {
		logger.Log(0, "error decrypting message", err.Error())
		return
	}
	serverNode := models.Node{}
	if err = json.Unmarshal([]byte(data), &serverNode); err != nil {
		logger.Log(0, "error unmarshalling node update data"+err.Error())
		return
	}
	newNode := config.Node{}
	newNode.CommonNode = serverNode.CommonNode

	// see if cache hit, if so skip
	var currentMessage = read(newNode.Network, lastNodeUpdate)
	if currentMessage == string(data) {
		logger.Log(3, "cache hit on node update ... skipping")
		return
	}
	insert(newNode.Network, lastNodeUpdate, string(data)) // store new message in cache
	logger.Log(0, "network:", newNode.Network, "received message to update node "+newNode.ID.String())
	// compare the locally stored node with the received one before the local
	// copy is overwritten below
	ifaceDelta := wireguard.IfaceDelta(&node, &newNode)
	keepaliveChange := node.PersistentKeepalive != newNode.PersistentKeepalive
	switch newNode.Action {
	case models.NODE_DELETE:
		logger.Log(0, "network:", newNode.Network, "received delete request for", newNode.ID.String())
		unsubscribeNode(client, &newNode)
		if _, err = LeaveNetwork(newNode.Network, true); err != nil {
			// Errors whose text contains "rpc error" are tolerated and the
			// node is still reported as removed; every other error aborts.
			// strings.Contains takes (s, substr): the error text is the
			// haystack and "rpc error" the needle.
			if !strings.Contains(err.Error(), "rpc error") {
				logger.Log(0, "failed to leave, please check that local files for network", newNode.Network, "were removed")
				return
			}
		}
		logger.Log(0, newNode.ID.String(), "was removed from network", newNode.Network)
		return
	case models.NODE_FORCE_UPDATE:
		// a forced update always notifies the server once the interface is updated
		ifaceDelta = true
	case models.NODE_NOOP:
	default:
	}
	// Save new config; the action is reset so it is not persisted with the node
	newNode.Action = models.NODE_NOOP
	config.UpdateNodeMap(network, newNode)
	if err := config.WriteNodeConfig(); err != nil {
		// a failed write is only logged; the interface is still updated below
		logger.Log(0, newNode.Network, "error updating node configuration: ", err.Error())
	}
	nc := wireguard.NewNCIface(config.Netclient(), config.GetNodes())
	if err := nc.Configure(); err != nil {
		logger.Log(0, "could not configure netmaker interface", err.Error())
		return
	}

	if err := wireguard.UpdateWgInterface(&newNode, config.Netclient()); err != nil {
		logger.Log(0, "error updating wireguard config "+err.Error())
		return
	}
	if keepaliveChange {
		wireguard.UpdateKeepAlive(int(newNode.PersistentKeepalive.Seconds()))
	}
	// NOTE(review): presumably gives the interface time to settle before the
	// server is signalled -- confirm
	time.Sleep(time.Second)
	if ifaceDelta { // if a change caused an ifacedelta we need to notify the server to update the peers
		doneErr := publishSignal(&newNode, DONE)
		if doneErr != nil {
			logger.Log(0, "network:", newNode.Network, "could not notify server to update peers after interface change")
		} else {
			logger.Log(0, "network:", newNode.Network, "signalled finished interface update to server")
		}
	}
}

// HostPeerUpdate - mq handler for host peer update peers/host/<HOSTID>/<SERVERNAME>
//
// The update is ignored when the host has no nodes or the server named in the
// topic is not configured. Otherwise the payload is decrypted and unmarshalled
// into a models.HostPeerUpdate, the client/server versions are reconciled, the
// peers are stored and applied to the interface, the default gateway is
// adjusted for internet-gateway peers, and endpoint detection is started in a
// separate goroutine. When the proxy is running the update is also forwarded
// on ProxyManagerChan.
func HostPeerUpdate(client mqtt.Client, msg mqtt.Message) {
	var peerUpdate models.HostPeerUpdate
	var err error
	if len(config.GetNodes()) == 0 {
		logger.Log(3, "skipping unwanted peer update, no nodes exist")
		return
	}
	serverName := parseServerFromTopic(msg.Topic())
	server := config.GetServer(serverName)
	if server == nil {
		logger.Log(0, "server ", serverName, " not found in config")
		return
	}
	logger.Log(3, "received peer update for host from: ", serverName)
	data, err := decryptMsg(serverName, msg.Payload())
	if err != nil {
		// log the failure like NodeUpdate does instead of dropping it silently
		logger.Log(0, "error decrypting message", err.Error())
		return
	}
	err = json.Unmarshal([]byte(data), &peerUpdate)
	if err != nil {
		logger.Log(0, "error unmarshalling peer data")
		return
	}
	if peerUpdate.ServerVersion != config.Version {
		logger.Log(0, "server/client version mismatch server: ", peerUpdate.ServerVersion, " client: ", config.Version)
		// only upgrade, and only when the host opted into auto updates
		if versionLessThan(config.Version, peerUpdate.ServerVersion) && config.Netclient().Host.AutoUpdate {
			if err := UseVersion(peerUpdate.ServerVersion, false); err != nil {
				logger.Log(0, "error updating client to server's version", err.Error())
			} else {
				logger.Log(0, "updated client to server's version: ", peerUpdate.ServerVersion, " ,restart daemon to reflect changes")
			}
		}
	}
	if peerUpdate.ServerVersion != server.Version {
		logger.Log(1, "updating server version")
		server.Version = peerUpdate.ServerVersion
		config.WriteServerConfig()
	}
	_, err = wireguard.UpdateWgPeers()
	if err != nil {
		logger.Log(0, "error updating wireguard peers"+err.Error())
		return
	}

	// capture the gateway state before the peers are replaced so that changes
	// can be detected afterwards
	gwDetected := config.GW4PeerDetected || config.GW6PeerDetected
	currentGW4 := config.GW4Addr
	currentGW6 := config.GW6Addr
	isInetGW := config.UpdateHostPeers(serverName, peerUpdate.Peers)
	// the steps below are best effort: a failure must not stop the remaining
	// peer handling, so their errors are intentionally not acted upon
	_ = config.WriteNetclientConfig()
	_ = wireguard.SetPeers()
	wireguard.GetInterface().GetPeerRoutes()
	if err = routes.SetNetmakerPeerEndpointRoutes(config.Netclient().DefaultInterface); err != nil {
		logger.Log(0, "error when setting peer routes after peer update", err.Error())
	}
	_ = wireguard.GetInterface().ApplyAddrs(true)
	// a delta exists when a gateway was set before and its address is now different
	gwDelta := (currentGW4.IP != nil && !currentGW4.IP.Equal(config.GW4Addr.IP)) ||
		(currentGW6.IP != nil && !currentGW6.IP.Equal(config.GW6Addr.IP))
	// The previous IPv4 gateway is the one to replace; fall back to the
	// previous IPv6 gateway only when no IPv4 gateway was set. The former
	// "!= nil" check discarded the IPv4 gateway exactly when it existed.
	originalGW := currentGW4
	if originalGW.IP == nil {
		originalGW = currentGW6
	}
	handlePeerInetGateways(
		gwDetected,
		isInetGW,
		gwDelta,
		&originalGW,
	)

	go handleEndpointDetection(&peerUpdate)
	if proxyCfg.GetCfg().IsProxyRunning() {
		time.Sleep(time.Second * 2) // sleep required to avoid race condition
		ProxyManagerChan <- &peerUpdate
	}

}

// HostUpdate - mq handler for host update host/update/<HOSTID>/<SERVERNAME>
//
// The payload is decrypted, unmarshalled into a models.HostUpdate and
// dispatched on its Action. Depending on the action the handler afterwards
// persists the host config and either restarts the daemon or recreates and
// reconfigures the interface. Unknown actions are logged and ignored.
func HostUpdate(client mqtt.Client, msg mqtt.Message) {
	var hostUpdate models.HostUpdate
	var err error
	serverName := parseServerFromTopic(msg.Topic())
	server := config.GetServer(serverName)
	if server == nil {
		logger.Log(0, "server ", serverName, " not found in config")
		return
	}
	data, err := decryptMsg(serverName, msg.Payload())
	if err != nil {
		// log the failure like NodeUpdate does instead of dropping it silently
		logger.Log(0, "error decrypting message", err.Error())
		return
	}
	err = json.Unmarshal([]byte(data), &hostUpdate)
	if err != nil {
		logger.Log(0, "error unmarshalling host update data")
		return
	}
	logger.Log(3, fmt.Sprintf("---> received host update [ action: %v ] for host from %s ", hostUpdate.Action, serverName))
	// resetInterface: recreate the interface after the switch
	// restartDaemon:  restart the daemon after the switch (takes precedence)
	// clearMsg:       clear the retained message right before that restart
	var resetInterface, restartDaemon, clearMsg bool
	switch hostUpdate.Action {
	case models.JoinHostToNetwork:
		commonNode := hostUpdate.Node.CommonNode
		nodeCfg := config.Node{
			CommonNode: commonNode,
		}
		config.UpdateNodeMap(hostUpdate.Node.Network, nodeCfg)
		server := config.GetServer(serverName)
		if server == nil {
			return
		}
		server.Nodes[hostUpdate.Node.Network] = true
		config.UpdateServer(serverName, *server)
		config.WriteNodeConfig()
		config.WriteServerConfig()
		logger.Log(1, "added node for network", hostUpdate.Node.Network, "on server", serverName)
		clearRetainedMsg(client, msg.Topic()) // clear message before ACK
		if err = PublishHostUpdate(serverName, models.Acknowledgement); err != nil {
			// include the error text, as the RequestAck branch does
			logger.Log(0, "failed to response with ACK to server", serverName, err.Error())
		}
		restartDaemon = true
	case models.DeleteHost:
		clearRetainedMsg(client, msg.Topic())
		unsubscribeHost(client, serverName)
		deleteHostCfg(client, serverName)
		config.WriteNodeConfig()
		config.WriteServerConfig()
		resetInterface = true
	case models.UpdateHost:
		resetInterface, restartDaemon = updateHostConfig(&hostUpdate.Host)
		clearMsg = true
	case models.RequestAck:
		clearRetainedMsg(client, msg.Topic()) // clear message before ACK
		if err = PublishHostUpdate(serverName, models.Acknowledgement); err != nil {
			logger.Log(0, "failed to response with ACK to server", serverName, err.Error())
		}
	case models.SignalHost:
		turn.PeerSignalCh <- hostUpdate.Signal
	case models.UpdateKeys:
		clearRetainedMsg(client, msg.Topic()) // clear message
		UpdateKeys()
	default:
		logger.Log(1, "unknown host action")
		return
	}
	if err = config.WriteNetclientConfig(); err != nil {
		logger.Log(0, "failed to write host config -", err.Error())
		return
	}

	if restartDaemon {
		if clearMsg {
			clearRetainedMsg(client, msg.Topic())
		}
		if err := daemon.Restart(); err != nil {
			logger.Log(0, "failed to restart daemon: ", err.Error())
		}
		// a restart supersedes an interface reset
		return
	}
	if resetInterface {
		// tear down the current interface and build a fresh one from the
		// current host and node configuration
		nc := wireguard.GetInterface()
		nc.Close()
		nc = wireguard.NewNCIface(config.Netclient(), config.GetNodes())
		nc.Create()
		if err := nc.Configure(); err != nil {
			logger.Log(0, "could not configure netmaker interface", err.Error())
			return
		}

		// endpoint routes are only set when the peers were applied successfully
		if err = wireguard.SetPeers(); err == nil {
			if err = routes.SetNetmakerPeerEndpointRoutes(config.Netclient().DefaultInterface); err != nil {
				logger.Log(0, "error when setting peer routes after host update", err.Error())
			}
		}
	}
}

// handleEndpointDetection probes the interface addresses advertised for each
// peer in peerUpdate and hands every usable candidate to
// networking.FindBestEndpoint. Candidates are skipped when the peer has no
// endpoint, the address is missing, the interface is a bridge network, or the
// address is 127.0.0.x, multicast, IPv6 link-local, equal to the current
// endpoint, or already covered by a peer's allowed IPs.
func handleEndpointDetection(peerUpdate *models.HostPeerUpdate) {
	ownKey := config.Netclient().PublicKey.String()
	// select best interface for each peer and set it as endpoint
	allowedCidrs := getAllAllowedIPs(peerUpdate.Peers[:])
	for p := range peerUpdate.Peers {
		peer := &peerUpdate.Peers[p]
		peerKey := peer.PublicKey.String()
		info, known := peerUpdate.HostNetworkInfo[peerKey]
		if !known {
			continue
		}
		for _, iface := range info.Interfaces {
			candidate := iface.Address.IP
			if peer.Endpoint == nil || candidate == nil {
				continue
			}
			// check to skip bridge network
			if ncutils.IsBridgeNetwork(iface.Name) {
				continue
			}
			if strings.Contains(candidate.String(), "127.0.0.") {
				continue
			}
			if candidate.IsMulticast() {
				continue
			}
			// two or more colons in the textual form: an IPv6 link-local address
			if candidate.IsLinkLocalUnicast() && strings.Count(candidate.String(), ":") >= 2 {
				continue
			}
			if peer.Endpoint.IP.Equal(candidate) {
				continue
			}
			if isAddressInPeers(candidate, allowedCidrs) {
				continue
			}
			err := networking.FindBestEndpoint(
				candidate.String(),
				ownKey,
				peerKey,
				info.ProxyListenPort,
			)
			if err != nil { // happens v often
				logger.Log(3, "failed to check for endpoint on peer", peerKey, err.Error())
			}
		}
	}
}

// deleteHostCfg removes everything stored for server: its host peer config,
// every node that belongs to it (each is unsubscribed first), the server
// entry itself and its entry in ServerSet.
func deleteHostCfg(client mqtt.Client, server string) {
	config.DeleteServerHostPeerCfg(server)
	for key, candidate := range config.GetNodes() {
		if candidate.Server != server {
			continue
		}
		// take a per-iteration copy so the pointer handed out is not the
		// shared loop variable
		owned := candidate
		unsubscribeNode(client, &owned)
		config.DeleteNode(key)
	}
	config.DeleteServer(server)
	// delete mq client from ServerSet map
	delete(ServerSet, server)
}

// updateHostConfig replaces the stored host settings with host, keeping the
// locally stored host password, and writes the netclient config.
// restart is true when a non-zero listen port or proxy listen port differs
// from the stored one; resetInterface is true when a non-zero MTU differs.
// Both are false when either the stored config or host is nil.
func updateHostConfig(host *models.Host) (resetInterface, restart bool) {
	current := config.Netclient()
	if current == nil || host == nil {
		return false, false
	}
	listenPortChanged := host.ListenPort != 0 && current.ListenPort != host.ListenPort
	proxyPortChanged := host.ProxyListenPort != 0 && current.ProxyListenPort != host.ProxyListenPort
	restart = listenPortChanged || proxyPortChanged
	resetInterface = host.MTU != 0 && current.MTU != host.MTU
	// store password before updating
	host.HostPass = current.HostPass
	current.Host = *host
	config.UpdateNetclient(*current)
	config.WriteNetclientConfig()
	return resetInterface, restart
}

// parseNetworkFromTopic returns the third "/"-separated segment of topic
// (index 2), which callers use as the network name. An empty string is
// returned when the topic has fewer than three segments, rather than
// panicking on an out-of-range index.
func parseNetworkFromTopic(topic string) string {
	const networkIdx = 2
	segments := strings.Split(topic, "/")
	if len(segments) <= networkIdx {
		return ""
	}
	return segments[networkIdx]
}

// parseServerFromTopic returns the fourth "/"-separated segment of topic
// (index 3), which callers use as the server name. An empty string is
// returned when the topic has fewer than four segments, rather than
// panicking on an out-of-range index.
func parseServerFromTopic(topic string) string {
	const serverIdx = 3
	segments := strings.Split(topic, "/")
	if len(segments) <= serverIdx {
		return ""
	}
	return segments[serverIdx]
}

// dnsUpdate - mq handler for host update dns/<HOSTID>/server
//
// A lock file in the OS temp directory is taken through config.Lock for the
// duration of the call. The payload is decrypted and unmarshalled into a
// models.DNSUpdate; a payload identical to the last cached dns update is
// skipped, anything else is cached and applied to the hosts file.
func dnsUpdate(client mqtt.Client, msg mqtt.Message) {
	temp := os.TempDir()
	lockfile := temp + "/netclient-lock"
	if err := config.Lock(lockfile); err != nil {
		logger.Log(0, "could not create lock file", err.Error())
		return
	}
	defer config.Unlock(lockfile)
	var dns models.DNSUpdate
	serverName := parseServerFromTopic(msg.Topic())
	server := config.GetServer(serverName)
	if server == nil {
		logger.Log(0, "server ", serverName, " not found in config")
		return
	}
	data, err := decryptMsg(serverName, msg.Payload())
	if err != nil {
		// log the failure like NodeUpdate does instead of dropping it silently
		logger.Log(0, "error decrypting message", err.Error())
		return
	}
	if err := json.Unmarshal([]byte(data), &dns); err != nil {
		logger.Log(0, "error unmarshalling dns update")
		// stop here: a payload that cannot be decoded must neither be cached
		// nor applied as a zero-value update to the hosts file
		return
	}
	if config.Netclient().Debug {
		log.Println("dnsUpdate received", dns)
	}
	var currentMessage = read("dns", lastDNSUpdate)
	if currentMessage == string(data) {
		logger.Log(3, "cache hit on dns update ... skipping")
		return
	}
	insert("dns", lastDNSUpdate, string(data))
	logger.Log(3, "received dns update for", dns.Name)
	applyDNSUpdate(dns)
}

// applyDNSUpdate applies one dns change to the default hosts file and saves
// it. Entries are added, removed by name or address, renamed, or moved to a
// new address depending on dns.Action; an unrecognised action only re-saves
// the file. A rename whose current address cannot be looked up is logged and
// abandoned without saving.
func applyDNSUpdate(dns models.DNSUpdate) {
	if config.Netclient().Debug {
		log.Println(dns)
	}
	etcHosts, loadErr := txeh.NewHostsDefault()
	if loadErr != nil {
		logger.Log(0, "failed to read hosts file", loadErr.Error())
		return
	}
	switch action := dns.Action; action {
	case models.DNSInsert:
		etcHosts.AddHost(dns.Address, dns.Name, etcHostsComment)
	case models.DNSDeleteByName:
		etcHosts.RemoveHost(dns.Name, etcHostsComment)
	case models.DNSDeleteByIP:
		etcHosts.RemoveAddress(dns.Address, etcHostsComment)
	case models.DNSReplaceName:
		// the update carries no address for a rename, so reuse the IPv4
		// address currently recorded for the old name
		found, addr, _ := etcHosts.HostAddressLookup(dns.Name, txeh.IPFamilyV4, etcHostsComment)
		if !found {
			logger.Log(2, "failed to find dns address for host", dns.Name)
			return
		}
		dns.Address = addr
		etcHosts.RemoveHost(dns.Name, etcHostsComment)
		etcHosts.AddHost(dns.Address, dns.NewName, etcHostsComment)
	case models.DNSReplaceIP:
		etcHosts.RemoveAddress(dns.Address, etcHostsComment)
		etcHosts.AddHost(dns.NewAddress, dns.Name, etcHostsComment)
	}
	if saveErr := etcHosts.Save(); saveErr != nil {
		logger.Log(0, "error saving hosts file", saveErr.Error())
	}
}

// dnsAll - mq handler for host update dnsall/<HOSTID>/server
//
// A lock file in the OS temp directory is taken through config.Lock for the
// duration of the call. The payload is decrypted and unmarshalled into a list
// of models.DNSUpdate entries; a payload identical to the last cached one is
// skipped, anything else is cached and applied to the hosts file.
func dnsAll(client mqtt.Client, msg mqtt.Message) {
	temp := os.TempDir()
	lockfile := temp + "/netclient-lock"
	if err := config.Lock(lockfile); err != nil {
		logger.Log(0, "could not create lock file", err.Error())
		return
	}
	defer config.Unlock(lockfile)
	var dns []models.DNSUpdate
	serverName := parseServerFromTopic(msg.Topic())
	server := config.GetServer(serverName)
	if server == nil {
		logger.Log(0, "server ", serverName, " not found in config")
		return
	}
	data, err := decryptMsg(serverName, msg.Payload())
	if err != nil {
		// log the failure like NodeUpdate does instead of dropping it silently
		logger.Log(0, "error decrypting message", err.Error())
		return
	}
	if err := json.Unmarshal([]byte(data), &dns); err != nil {
		logger.Log(0, "error unmarshalling dns update")
		// stop here: a payload that cannot be decoded must not be cached,
		// otherwise it would be treated as the last applied dns state
		return
	}
	if config.Netclient().Debug {
		log.Println("all dns", dns)
	}
	var currentMessage = read("dnsall", lastALLDNSUpdate)
	logger.Log(3, "received initial dns")
	if currentMessage == string(data) {
		logger.Log(3, "cache hit on all dns ... skipping")
		if config.Netclient().Debug {
			log.Println("dns cache", currentMessage, string(data))
		}
		return
	}
	insert("dnsall", lastALLDNSUpdate, string(data))
	applyAllDNS(dns)
}

// applyAllDNS adds every DNSInsert entry of dns to the default hosts file and
// saves it. Entries carrying any other action are logged and skipped.
func applyAllDNS(dns []models.DNSUpdate) {
	etcHosts, err := txeh.NewHostsDefault()
	if err != nil {
		logger.Log(0, "failed to read hosts file", err.Error())
		return
	}
	for i := range dns {
		if action := dns[i].Action; action != models.DNSInsert {
			logger.Log(0, "invalid dns actions", action.String())
			continue
		}
		etcHosts.AddHost(dns[i].Address, dns[i].Name, etcHostsComment)
	}

	if err = etcHosts.Save(); err != nil {
		logger.Log(0, "error saving hosts file", err.Error())
	}
}

// getAllAllowedIPs flattens the AllowedIPs of all peers into a single slice.
// The result is never nil: an empty slice is returned when there are no peers
// or none of them has allowed IPs.
func getAllAllowedIPs(peers []wgtypes.PeerConfig) (cidrs []net.IPNet) {
	cidrs = []net.IPNet{}
	for idx := range peers {
		cidrs = append(cidrs, peers[idx].AllowedIPs...)
	}
	return cidrs
}

// isAddressInPeers reports whether ip falls inside at least one of cidrs.
// An empty or nil cidrs slice yields false.
func isAddressInPeers(ip net.IP, cidrs []net.IPNet) bool {
	for idx := range cidrs {
		if network := &cidrs[idx]; network.Contains(ip) {
			return true
		}
	}
	return false
}

// handlePeerInetGateways adjusts the default gateway after a peer update.
//
//   - gwDetected: whether a gateway peer (IPv4 or IPv6) had been detected
//     before the update was applied
//   - isHostInetGateway: when true, no default gateway is set towards a peer
//   - gwDelta: whether the gateway address changed; the default gateway is
//     then switched to the newly detected gateway peer
//   - originalGW: the gateway address in use before the update; it is removed
//     when switching to a new IPv4 gateway
//
// All failures are logged at verbosity 3 and otherwise ignored.
func handlePeerInetGateways(gwDetected, isHostInetGateway, gwDelta bool, originalGW *net.IPNet) { // isHostInetGateway indicates if host should worry about setting gateways
	if gwDelta { // handle switching gateway IP to other GW peer
		if config.GW4PeerDetected {
			if err := routes.RemoveDefaultGW(originalGW); err != nil {
				logger.Log(3, "failed to remove default gateway from peer", originalGW.String(), err.Error())
			}
			if err := routes.SetDefaultGateway(&config.GW4Addr); err != nil {
				logger.Log(3, "failed to change default gateway to peer", config.GW4Addr.String(), err.Error())
			}
		} else if config.GW6PeerDetected {
			// NOTE(review): unlike the IPv4 branch, the original gateway is
			// not removed here -- confirm whether that is intended
			if err := routes.SetDefaultGateway(&config.GW6Addr); err != nil {
				// report the IPv6 address that was actually being set
				logger.Log(3, "failed to set default gateway to peer", config.GW6Addr.String(), err.Error())
			}
		}
	} else {
		if !gwDetected && config.GW4PeerDetected && !isHostInetGateway { // ipv4 gateways take priority
			if err := routes.SetDefaultGateway(&config.GW4Addr); err != nil {
				logger.Log(3, "failed to set default gateway to peer", config.GW4Addr.String(), err.Error())
			}
		} else if gwDetected && !config.GW4PeerDetected {
			if err := routes.RemoveDefaultGW(&config.GW4Addr); err != nil {
				logger.Log(3, "failed to remove default gateway to peer", config.GW4Addr.String(), err.Error())
			}
		} else if !gwDetected && config.GW6PeerDetected && !isHostInetGateway {
			if err := routes.SetDefaultGateway(&config.GW6Addr); err != nil {
				logger.Log(3, "failed to set default gateway to peer", config.GW6Addr.String(), err.Error())
			}
		} else if gwDetected && !config.GW6PeerDetected {
			if err := routes.RemoveDefaultGW(&config.GW6Addr); err != nil {
				logger.Log(3, "failed to remove default gateway to peer", config.GW6Addr.String(), err.Error())
			}
		}
	}
}
