// This file is dual licensed under the terms of the Apache License, Version
// 2.0, and the BSD License. See the LICENSE file in the root of this repository
// for complete details.

use crate::{cvt, cvt_p, OpenSSLResult};
use foreign_types_shared::{ForeignType, ForeignTypeRef};

/// AEAD algorithms that an [`AeadCtx`] can be created for via [`AeadCtx::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AeadType {
    /// ChaCha20-Poly1305, selected through `EVP_aead_chacha20_poly1305`.
    ChaCha20Poly1305,
}

// Generates an owned handle (`AeadCtx`) and a borrowed reference type
// (`AeadCtxRef`) wrapping a raw `ffi::EVP_AEAD_CTX` pointer. The owned type
// releases the C object with `EVP_AEAD_CTX_free` when dropped; the
// encrypt/decrypt operations below are implemented on the borrowed type.
foreign_types::foreign_type! {
    // The underlying C context type being wrapped.
    type CType = ffi::EVP_AEAD_CTX;
    // Destructor invoked when an owned `AeadCtx` goes out of scope.
    fn drop = ffi::EVP_AEAD_CTX_free;

    // Owning handle; constructed in `AeadCtx::new`.
    pub struct AeadCtx;
    // Borrowed view of the context; carries `encrypt` / `decrypt`.
    pub struct AeadCtxRef;
}

// SAFETY: NOTE(review): these impls assert that an `EVP_AEAD_CTX` may be moved
// to another thread and shared by reference across threads. The operations
// exposed on it take `&self` and only hand the pointer to
// `EVP_AEAD_CTX_seal` / `EVP_AEAD_CTX_open`. BoringSSL declares those
// functions as taking a `const EVP_AEAD_CTX *`, and under its API conventions
// that permits concurrent calls. Confirm this against the linked library
// before relying on it.
unsafe impl Send for AeadCtx {}
unsafe impl Sync for AeadCtx {}

impl AeadCtx {
    /// Creates a context for the `aead` algorithm keyed with `key`.
    ///
    /// The tag length is `ffi::EVP_AEAD_DEFAULT_TAG_LENGTH`, which asks the
    /// library to use the algorithm's default.
    ///
    /// # Errors
    ///
    /// Returns the library error stack if `EVP_AEAD_CTX_new` returns a null
    /// pointer. Presumably this includes a `key` whose length is wrong for the
    /// algorithm; see BoringSSL's `EVP_AEAD_CTX_new` documentation.
    pub fn new(aead: AeadType, key: &[u8]) -> OpenSSLResult<AeadCtx> {
        let algorithm = match aead {
            // SAFETY: the algorithm getter takes no arguments.
            AeadType::ChaCha20Poly1305 => unsafe { ffi::EVP_aead_chacha20_poly1305() },
        };

        // SAFETY: `key` is a live slice passed together with its own length.
        let raw = cvt_p(unsafe {
            ffi::EVP_AEAD_CTX_new(
                algorithm,
                key.as_ptr(),
                key.len(),
                ffi::EVP_AEAD_DEFAULT_TAG_LENGTH as usize,
            )
        })?;

        // SAFETY: `cvt_p` rejected null, and the pointer was just returned by
        // `EVP_AEAD_CTX_new`. The wrapper takes ownership and frees it on drop.
        Ok(unsafe { AeadCtx::from_ptr(raw) })
    }
}

impl AeadCtxRef {
    /// Encrypts and authenticates `data` and authenticates the associated
    /// data `ad`, all under `nonce`, writing the sealed output into `out`.
    ///
    /// This wraps `EVP_AEAD_CTX_seal`, passing `out.len()` as the maximum
    /// output length. The byte count that the C call reports back is not
    /// returned, so callers should size `out` to exactly the expected output
    /// length. Per BoringSSL's docs that is presumably `data.len()` plus the
    /// algorithm's overhead; confirm at call sites.
    ///
    /// # Errors
    ///
    /// Returns the library error stack if `EVP_AEAD_CTX_seal` reports failure.
    /// Per BoringSSL's docs this includes a bad nonce length and an `out`
    /// buffer that is too small.
    pub fn encrypt(
        &self,
        data: &[u8],
        nonce: &[u8],
        ad: &[u8],
        out: &mut [u8],
    ) -> OpenSSLResult<()> {
        let capacity = out.len();
        // Output parameter of the C call; its value is never read afterwards.
        let mut written = capacity;
        // SAFETY: every input pointer comes from a live slice paired with that
        // slice's own length, `out` is writable for `capacity` bytes, and
        // `written` outlives the call.
        let rc = unsafe {
            ffi::EVP_AEAD_CTX_seal(
                self.as_ptr(),
                out.as_mut_ptr(),
                &mut written,
                capacity,
                nonce.as_ptr(),
                nonce.len(),
                data.as_ptr(),
                data.len(),
                ad.as_ptr(),
                ad.len(),
            )
        };
        cvt(rc)?;
        Ok(())
    }

    /// Authenticates and decrypts `data`, which is sealed output produced by
    /// `encrypt`, using `nonce` and associated data `ad`. The plaintext is
    /// written into `out`.
    ///
    /// This wraps `EVP_AEAD_CTX_open`, passing `out.len()` as the maximum
    /// output length. As with `encrypt`, the byte count that the C call
    /// reports back is discarded, so `out` should be sized exactly to the
    /// expected plaintext length.
    ///
    /// # Errors
    ///
    /// Returns the library error stack if `EVP_AEAD_CTX_open` reports failure.
    /// Per BoringSSL's docs this includes an authentication failure, meaning a
    /// tampered ciphertext, a wrong key or nonce, or mismatched `ad`.
    pub fn decrypt(
        &self,
        data: &[u8],
        nonce: &[u8],
        ad: &[u8],
        out: &mut [u8],
    ) -> OpenSSLResult<()> {
        let capacity = out.len();
        // Output parameter of the C call; its value is never read afterwards.
        let mut written = capacity;
        // SAFETY: every input pointer comes from a live slice paired with that
        // slice's own length, `out` is writable for `capacity` bytes, and
        // `written` outlives the call.
        let rc = unsafe {
            ffi::EVP_AEAD_CTX_open(
                self.as_ptr(),
                out.as_mut_ptr(),
                &mut written,
                capacity,
                nonce.as_ptr(),
                nonce.len(),
                data.as_ptr(),
                data.len(),
                ad.as_ptr(),
                ad.len(),
            )
        };
        cvt(rc)?;
        Ok(())
    }
}
