/*
 * This software is Copyright (c) 2013 Lukas Odzioba <ukasz at openwall dot net>
 * and it is hereby released to the general public under the following terms:
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted.
 */
#ifdef HAVE_OPENCL

#if FMT_EXTERNS_H
extern struct fmt_main fmt_opencl_pbkdf2_hmac_sha256;
#elif FMT_REGISTERS_H
john_register_one(&fmt_opencl_pbkdf2_hmac_sha256);
#else

#include <assert.h>
#include <ctype.h>
#include <stdint.h>
#include <string.h>
#include "misc.h"
#include "arch.h"
#include "base64_convert.h"
#include "common.h"
#include "formats.h"
#include "options.h"
#include "common-opencl.h"

#define FORMAT_LABEL		"PBKDF2-HMAC-SHA256-opencl"
#define FORMAT_NAME		""
#define ALGORITHM_NAME		"PBKDF2-SHA256 OpenCL"

#define BENCHMARK_COMMENT	", rounds=12000"
#define BENCHMARK_LENGTH	-1

#define BINARY_ALIGN		4
/*
 * salt_t contains a uint32_t (rounds) that is read through salt_t pointers
 * (e.g. iteration_count()), so stored salts need 4-byte alignment.
 */
#define SALT_ALIGN		4

/*
 * uint8_t / uint32_t come from <stdint.h>.  They used to be #defined here as
 * macros, which breaks any later header that typedefs them itself.
 */

#define PLAINTEXT_LENGTH	55
#define SALT_LENGTH		50	/* max decoded salt bytes accepted by valid() */
#define BINARY_SIZE		32	/* 256-bit PBKDF2-HMAC-SHA256 output */
#define SALT_SIZE		sizeof(salt_t)

#define FMT_PREFIX		"$pbkdf2-sha256$"
#define FMT_CISCO8		"$8$"
#define KERNEL_NAME		"pbkdf2_sha256_kernel"
#define SPLIT_KERNEL_NAME	"pbkdf2_sha256_loop"

/* Guarded so that an existing definition from a system header is not
 * redefined.  Note: both evaluate their arguments more than once. */
#ifndef MIN
#define MIN(a, b)		(((a) < (b)) ? (a) : (b))
#endif
#ifndef MAX
#define MAX(a, b)		(((a) > (b)) ? (a) : (b))
#endif
/*
 * Iterations per launch of the loop kernel.  13 * 13 * 71 == 11999 ==
 * ITERATIONS - 1, so the default round count presumably splits into exactly
 * 13 launches of HASH_LOOPS after the first iteration.
 */
#define HASH_LOOPS		(13*71)
#define ITERATIONS		12000

/*
 * Host-side mirrors of the buffers shared with the OpenCL kernels.
 * NOTE(review): the field order and sizes presumably have to match the
 * structs declared in kernels/pbkdf2_hmac_sha256_kernel.cl — keep in sync.
 */

/* One candidate password: `length` valid bytes in `v`, not NUL-terminated. */
typedef struct {
	uint8_t length;
	uint8_t v[PLAINTEXT_LENGTH];
} pass_t;

/* One computed 256-bit result, as eight 32-bit words. */
typedef struct {
	uint32_t hash[8];
} crack_t;

/*
 * Per-hash salt: `length` bytes in `salt` (valid() caps decoded salts at
 * SALT_LENGTH bytes) plus the PBKDF2 iteration count.
 */
typedef struct {
	uint8_t length;
	uint8_t salt[64];
	uint32_t rounds;
} salt_t;

/*
 * Per-work-item state kept on the device between loop-kernel launches.
 * Only its size is used on the host (create_clobj / auto-tune setup).
 * The names suggest HMAC inner/outer pad state, the running hash and
 * remaining rounds — TODO confirm against the kernel.
 */
typedef struct {
	uint32_t ipad[8];
	uint32_t opad[8];
	uint32_t hash[8];
	uint32_t W[8];
	uint32_t rounds;
} state_t;

/*
	Test cases generated by passlib, format: $pbkdf2-sha256$rounds$salt$checksum
	The salt and checksum are encoded in passlib's "adapted base64"
	(MIME base64 with '.' in place of '+' and no padding).
*/
static struct fmt_tests tests[] = {

	{"$pbkdf2-sha256$12000$2NtbSwkhRChF6D3nvJfSGg$OEWLc4keep8Vx3S/WnXgsfalb9q0RQdS1s05LfalSG4", ""},
	{"$pbkdf2-sha256$12000$fK8VAoDQuvees5ayVkpp7Q$xfzKAoBR/Iaa68tjn.O8KfGxV.zdidcqEeDoTFvDz2A", "1"},
	{"$pbkdf2-sha256$12000$GoMQYsxZ6/0fo5QyhtAaAw$xQ9L6toKn0q245SIZKoYjCu/Fy15hwGme9.08hBde1w", "12"},
	{"$pbkdf2-sha256$12000$6r3XWgvh/D/HeA/hXAshJA$11YY39OaSkJuwb.ONKVy5ebCZ00i5f8Qpcgwfe3d5kY", "123"},
	{"$pbkdf2-sha256$12000$09q711rLmbMWYgwBIGRMqQ$kHdAHlnQ1i1FHKBCPLV0sA20ai2xtYA1Ev8ODfIkiQg", "1234"},
	{"$pbkdf2-sha256$12000$Nebce08pJcT43zuHUMo5Rw$bMW/EsVqy8tMaDecFwuZNEPVfQbXBclwN78okLrxJoA", "openwall"},
	{"$pbkdf2-sha256$12000$mtP6/39PSQlhzBmDsJZS6g$zUXxf/9XBGrkedXVwhpC9wLLwwKSvHX39QRz7MeojYE", "password"},
	{"$pbkdf2-sha256$12000$35tzjhGi9J5TSilF6L0XAg$MiJA1gPN1nkuaKPVzSJMUL7ucH4bWIQetzX/JrXRYpw", "pbkdf2-sha256"},
	{"$pbkdf2-sha256$12000$sxbCeE8pxVjL2ds7hxBizA$uIiwKdo9DbPiiaLi1y3Ljv.r9G1tzxLRdlkD1uIOwKM", " 15 characters "},
	{"$pbkdf2-sha256$12000$CUGI8V7rHeP8nzMmhJDyXg$qjq3rBcsUgahqSO/W4B1bvsuWnrmmC4IW8WKMc5bKYE", " 16 characters__"},
	{"$pbkdf2-sha256$12000$FmIM4VxLaY1xLuWc8z6n1A$OVe6U1d5dJzYFKlJsZrW1NzUrfgiTpb9R5cAfn96WCk", " 20 characters______"},
	{"$pbkdf2-sha256$12000$fA8BAMAY41wrRQihdO4dow$I9BSCuV6UjG55LktTKbV.bIXtyqKKNvT3uL7JQwMLp8", " 24 characters______1234"},
	{"$pbkdf2-sha256$12000$/j8npJTSOmdMKcWYszYGgA$PbhiSNRzrELfAavXEsLI1FfitlVjv9NIB.jU1HHRdC8", " 28 characters______12345678"},
	{"$pbkdf2-sha256$12000$xfj/f6/1PkcIoXROCeE8Bw$ci.FEcPOKKKhX5b3JwzSDo6TGuYjgj1jKfCTZ9UpDM0", " 32 characters______123456789012"},
	{"$pbkdf2-sha256$12000$6f3fW8tZq7WWUmptzfmfEw$GDm/yhq1TnNR1MVGy73UngeOg9QJ7DtW4BnmV2F065s", " 40 characters______12345678901234567890"},
	{"$pbkdf2-sha256$12000$dU5p7T2ndM7535tzjpGyVg$ILbppLkipmonlfH1I2W3/vFMyr2xvCI8QhksH8DWn/M", " 55 characters______________________________________end"},
	{"$pbkdf2-sha256$12000$iDFmDCHE2FtrDaGUEmKMEaL0Xqv1/t/b.x.DcC6lFEI$tUdEcw3csCnsfiYbFdXH6nvbftH8rzvBDl1nABeN0nE", "salt length = 32"},
	{"$pbkdf2-sha256$12000$0zoHwNgbIwSAkDImZGwNQUjpHcNYa43xPqd0DuH8H0OIUWqttfY.h5DynvPeG.O8N.Y$.XK4LNIeewI7w9QF5g9p5/NOYMYrApW03bcv/MaD6YQ", "salt length = 50"},
	{"$pbkdf2-sha256$12000$HGPMeS9lTAkhROhd653Tuvc.ZyxFSOk9x5gTYgyBEAIAgND6PwfAmA$WdCipc7O/9tTgbpZvcz.mAkIDkdrebVKBUgGbncvoNw", "salt length = 40"},
	// Non-default round count (not a multiple of HASH_LOOPS).
	{"$pbkdf2-sha256$12001$ay2F0No7p1QKgVAqpbQ2hg$UbKdswiLpjc5wT8Zl2M6VlE2cNiKuhAUntGciP8JjPw", "test"},
	// Cisco type 8 hashes: 20000 iterations and a different base64 alphabet (crypt-style, same as WPA).
	// The 14-character salt is used RAW; it is not base64-decoded before use.  prepare() converts these.
	{"$8$dsYGNam3K1SIJO$7nv/35M/qr6t.dVc7UY9zrJDWRVqncHub1PE9UlMQFs", "cisco"},
	{"$8$6NHinlEjiwvb5J$RjC.H.ydVb34wDLqJvfjyG1ubxYKpfXqv.Ry9mtrNBY", "password"},
	{"$8$lGO8juTOQLPCHw$cBv2WEaFCLUA24Z48CKUGixIywyGFP78r/slQcMXr3M", "JtR"},
	{NULL}
};

// Uncomment to enable the checksum-length assertion in binary().
//#define DEBUG
static pass_t *host_pass;			      /** candidate plaintexts, one per work item **/
static salt_t *host_salt;			      /** current salt (copied by set_salt) **/
static crack_t *host_crack;			      /** computed hashes read back from the device **/
static cl_int cl_error;
static cl_mem mem_in, mem_out, mem_salt, mem_state;
static cl_kernel split_kernel;

#define STEP			0
#define SEED			1024
#define OCL_CONFIG		"pbkdf2-hmac-sha256"

/*
 * Labels for the four profiling events recorded by crypt_all_benchmark():
 * 0 = key upload, 1 = init kernel, 2 = loop kernel, 3 = result read-back.
 */
static const char * warn[] = {
        "xfer: ",  ", init: " , ", crypt: ", ", res xfer: "
};

/* Presumably marks event 2 (the loop kernel) as the one the auto-tuner
 * scales by the number of split launches — confirm in opencl-autotune.h. */
static int split_events[] = { 2, -1, -1 };

/* Forward declarations: init() installs crypt_all_benchmark while the
 * auto-tuner runs and then restores crypt_all; both are defined below. */
static int crypt_all(int *pcount, struct db_salt *_salt);
static int crypt_all_benchmark(int *pcount, struct db_salt *_salt);

// This file contains auto-tuning routine(s). Has to be included after formats definitions.
#include "opencl-autotune.h"
#include "memdbg.h"

/*
 * Allocate host and device buffers for `kpc` candidates and bind them to the
 * kernels.  Undone by release_clobj().
 *
 * Kernel arguments:
 *   crypt_kernel (init): 0 = mem_in, 1 = mem_salt, 2 = mem_state
 *   split_kernel (loop): 0 = mem_state, 1 = mem_out
 *
 * `self` is unused here.  NOTE(review): mem_calloc is the project allocator;
 * presumably it zero-fills and aborts on failure — confirm in memory.h.
 */
static void create_clobj(size_t kpc, struct fmt_main *self)
{
#define CL_RO CL_MEM_READ_ONLY
#define CL_WO CL_MEM_WRITE_ONLY
#define CL_RW CL_MEM_READ_WRITE

/*
 * NOTE: expands to TWO statements (the clCreateBuffer call, then the error
 * check), so it only works as the right-hand side of a plain assignment
 * statement, as used below.  These helper macros are never #undef'd.
 */
#define CLCREATEBUFFER(_flags, _size, _string)\
	clCreateBuffer(context[gpu_id], _flags, _size, NULL, &cl_error);\
	HANDLE_CLERROR(cl_error, _string);

#define CLKERNELARG(kernel, id, arg, msg)\
	HANDLE_CLERROR(clSetKernelArg(kernel, id, sizeof(arg), &arg), msg);

	host_pass = mem_calloc(kpc * sizeof(pass_t));
	host_crack = mem_calloc(kpc * sizeof(crack_t));
	host_salt = mem_calloc(sizeof(salt_t));

	mem_in = CLCREATEBUFFER(CL_RO, kpc * sizeof(pass_t),
	                        "Cannot allocate mem in");
	mem_salt = CLCREATEBUFFER(CL_RO, sizeof(salt_t),
	                          "Cannot allocate mem salt");
	mem_out = CLCREATEBUFFER(CL_WO, kpc * sizeof(crack_t),
	                         "Cannot allocate mem out");
	// Per-candidate state carried between loop-kernel launches.
	mem_state = CLCREATEBUFFER(CL_RW, kpc * sizeof(state_t),
	                           "Cannot allocate mem state");

	CLKERNELARG(crypt_kernel, 0, mem_in, "Error while setting mem_in");
	CLKERNELARG(crypt_kernel, 1, mem_salt, "Error while setting mem_salt");
	CLKERNELARG(crypt_kernel, 2, mem_state, "Error while setting mem_state");

	CLKERNELARG(split_kernel, 0, mem_state, "Error while setting mem_state");
	CLKERNELARG(split_kernel, 1 ,mem_out, "Error while setting mem_out");
}

/* ------- Helper functions ------- */

/*
 * Largest work-group size usable by both kernels.  crypt_all() launches the
 * init and loop kernels with the same local_work_size, so the result is the
 * smaller of the two per-kernel limits.
 *
 * Each limit is queried exactly once; the previous MIN(s, f()) form could
 * call the loop-kernel query twice because MIN is a macro.
 */
static size_t get_task_max_work_group_size(void)
{
	size_t init_max = autotune_get_task_max_work_group_size(FALSE, 0, crypt_kernel);
	size_t loop_max = autotune_get_task_max_work_group_size(FALSE, 0, split_kernel);

	return MIN(init_max, loop_max);
}

/*
 * Format-specific upper bound on the global work size for auto-tuning.
 * This format imposes none and returns 0, presumably meaning "no limit /
 * use the auto-tuner's default" — confirm in opencl-autotune.h.
 */
static size_t get_task_max_size(void)
{
	return 0;
}

/*
 * Starting local work size for auto-tuning: 1 on a CPU device, 128 on
 * anything else.
 */
static size_t get_default_workgroup(void)
{
	return cpu(device_info[gpu_id]) ? 1 : 128;
}

/*
 * Tear down everything create_clobj() allocated: first the four device
 * buffers, then the host-side mirrors.  Any OpenCL failure is reported
 * through HANDLE_CLERROR.
 */
static void release_clobj(void)
{
	/* Device buffers. */
	HANDLE_CLERROR(clReleaseMemObject(mem_in), "Release mem in");
	HANDLE_CLERROR(clReleaseMemObject(mem_salt), "Release mem salt");
	HANDLE_CLERROR(clReleaseMemObject(mem_out), "Release mem out");
	HANDLE_CLERROR(clReleaseMemObject(mem_state), "Release mem state");

	/* Host mirrors (independent of each other, any order is fine). */
	MEM_FREE(host_crack);
	MEM_FREE(host_salt);
	MEM_FREE(host_pass);
}

/*
 * Format initialisation.
 *
 * Builds the OpenCL program and creates the init (KERNEL_NAME) and loop
 * (SPLIT_KERNEL_NAME) kernels.  It then runs the shared auto-tuner.
 *
 * HASH_LOOPS and PLAINTEXT_LENGTH are passed to the kernel build as -D
 * options, so the device code sees the same values as this file.  While the
 * auto-tuner runs, crypt_all is temporarily replaced by crypt_all_benchmark;
 * it is restored afterwards.
 */
static void init(struct fmt_main *self)
{
	char build_opts[64];

	/* Both macros are plain int constants, so format them with %d. */
	snprintf(build_opts, sizeof(build_opts),
	         "-DHASH_LOOPS=%d -DPLAINTEXT_LENGTH=%d",
	         HASH_LOOPS, PLAINTEXT_LENGTH);
	opencl_init("$JOHN/kernels/pbkdf2_hmac_sha256_kernel.cl",
	            gpu_id, build_opts);

	crypt_kernel =
	    clCreateKernel(program[gpu_id], KERNEL_NAME, &cl_error);
	HANDLE_CLERROR(cl_error, "Error creating crypt kernel");

	split_kernel =
	    clCreateKernel(program[gpu_id], SPLIT_KERNEL_NAME, &cl_error);
	HANDLE_CLERROR(cl_error, "Error creating split kernel");

	// Initialize openCL tuning (library) for this format.
	opencl_init_auto_setup(SEED, HASH_LOOPS, split_events,
		warn, 2, self, create_clobj, release_clobj,
		sizeof(state_t), 0);

	// Auto tune execution from shared/included code.  The last argument
	// differs for CPU and non-CPU devices; presumably a time budget in
	// nanoseconds — confirm in opencl-autotune.h.
	self->methods.crypt_all = crypt_all_benchmark;
	autotune_run(self, ITERATIONS, 0,
	             (cpu(device_info[gpu_id]) ? 1000000000 : 10000000000ULL));
	self->methods.crypt_all = crypt_all;
}

/*
 * Format shutdown: release the buffers (via release_clobj), then both
 * kernels, then the program built for this device.
 */
static void done(void)
{
	release_clobj();

	HANDLE_CLERROR(clReleaseKernel(crypt_kernel), "Release kernel 1");
	HANDLE_CLERROR(clReleaseKernel(split_kernel), "Release kernel 2");

	HANDLE_CLERROR(clReleaseProgram(program[gpu_id]), "Release Program");
}

/*
 * Convert a Cisco "type 8" hash into this format's canonical form.
 *
 *   $8$<14-char salt>$<43-char checksum>
 *     ->  $pbkdf2-sha256$20000$<salt>$<checksum>
 *
 * The salt is copied verbatim, because Cisco uses it raw.  The checksum is
 * re-encoded from the crypt base64 alphabet to MIME base64, with every '+'
 * then written as '.'.
 *
 * Anything that is not a well-formed Cisco type 8 hash is returned unchanged
 * (fields[1]).  Only fields[1] is read.  The converted string lives in a
 * static buffer that the next call overwrites.
 */
static char *prepare(char *fields[10], struct fmt_main *self)
{
	static char Buf[120];
	char tmp[43+3], *cp;
	char *ct = fields[1];
	const size_t tag_len = strlen(FMT_CISCO8);
	const size_t salt_len = 14, sum_len = 43;

	if (strncmp(ct, FMT_CISCO8, tag_len) != 0)
		return fields[1];
	if (strlen(ct) != tag_len + salt_len + 1 + sum_len)
		return fields[1];
	/* Salt and checksum must be separated by '$'; otherwise the byte there
	 * would be silently dropped and a malformed hash accepted. */
	if (ct[tag_len + salt_len] != '$')
		return fields[1];

	snprintf(Buf, sizeof(Buf), "%s20000$%14.14s$%s", FMT_PREFIX, ct + tag_len,
		base64_convert_cp(ct + tag_len + salt_len + 1, e_b64_crypt, 43, tmp,
		                  e_b64_mime, sizeof(tmp), flg_Base64_NO_FLAGS));

	/* Map MIME '+' to the "adapted base64" '.' used by this format. */
	cp = strchr(Buf, '+');
	while (cp) {
		*cp = '.';
		cp = strchr(cp, '+');
	}
	return Buf;
}

/*
 * Return 1 if `ciphertext` is a well-formed hash for this format, else 0.
 *
 * Accepted form:  $pbkdf2-sha256$<rounds>$<salt>$<checksum>
 *   - rounds:   a positive decimal integer followed directly by '$'
 *   - salt:     adapted-base64 text decoding to at most SALT_LENGTH bytes
 *   - checksum: exactly 43 adapted-base64 characters, then end of string
 *
 * A raw Cisco "$8$..." hash is first converted with prepare().
 */
static int valid(char *ciphertext, struct fmt_main *pFmt)
{
	int saltlen;
	long rounds;
	char *p, *c, *end;

	if (strncmp(ciphertext, FMT_CISCO8, strlen(FMT_CISCO8)) == 0) {
		char *f[10] = { NULL };
		f[1] = ciphertext;
		ciphertext = prepare(f, pFmt);
	}
	if (strncmp(ciphertext, FMT_PREFIX, strlen(FMT_PREFIX)) != 0)
		return 0;
	if (strlen(ciphertext) < 44 + strlen(FMT_PREFIX))
		return 0;

	/* Parse the (possibly converted) string, not the raw input. */
	c = ciphertext + strlen(FMT_PREFIX);

	/* Rounds: digits immediately followed by '$'; reject 0 and negatives. */
	rounds = strtol(c, &end, 10);
	if (end == c || *end != '$' || rounds < 1)
		return 0;
	c = end + 1;

	p = strchr(c, '$');
	if (p == NULL)
		return 0;
	saltlen = base64_valid_length(c, e_b64_mime, flg_Base64_MIME_PLUS_TO_DOT);
	c += saltlen;
	saltlen = B64_TO_RAW_LEN(saltlen);
	if (saltlen > SALT_LENGTH)
		return 0;
	if (*c != '$')
		return 0;
	c++;

	/* binary() decodes exactly 43 characters; nothing may follow them. */
	if (base64_valid_length(c, e_b64_mime, flg_Base64_MIME_PLUS_TO_DOT) != 43)
		return 0;
	if (c[43] != '\0')
		return 0;
	return 1;
}

/*
 * Extract the rounds and salt from a canonical "$pbkdf2-sha256$" ciphertext
 * that valid() has already accepted.  The result is a static salt_t, zeroed
 * first and overwritten on every call.
 *
 * A 14-character salt combined with exactly 20000 rounds is treated as a
 * converted Cisco type 8 hash (see prepare()), and its salt is copied
 * verbatim.  NOTE: a genuine pbkdf2-sha256 hash with those same parameters
 * is indistinguishable and gets the same treatment.
 *
 * Every other salt is decoded from adapted base64 ('.' for '+') into raw
 * bytes.
 */
static void *get_salt(char *ciphertext)
{
	static salt_t salt;
	char *field, *field_end;
	int field_len;

	memset(&salt, 0, sizeof(salt));

	field = ciphertext + strlen(FMT_PREFIX);
	salt.rounds = strtol(field, NULL, 10);

	field = strchr(field, '$') + 1;
	field_end = strchr(field, '$');
	field_len = field_end - field;

	if (field_len == 14 && salt.rounds == 20000) {
		// for now, assume this is a cisco8 hash
		strnzcpy((char *)salt.salt, field, 15);
		salt.length = 14;
	} else {
		salt.length = base64_convert(field, e_b64_mime, field_len,
		                             salt.salt, e_b64_raw, sizeof(salt.salt),
		                             flg_Base64_MIME_PLUS_TO_DOT);
	}
	return (void *)&salt;
}

static void *binary(char *ciphertext)
{
	static char ret[256 / 8 + 1];
	char *c = ciphertext;

	c += strlen(FMT_PREFIX) + 1;
	c = strchr(c, '$') + 1;
	c = strchr(c, '$') + 1;
#ifdef DEBUG
	assert(strlen(c) == 43);
#endif
	base64_convert(c, e_b64_mime, 43, ret, e_b64_raw, sizeof(ret), flg_Base64_MIME_PLUS_TO_DOT);
	return ret;
}

/*
 * Make `salt` current: copy it into host_salt, then queue a non-blocking
 * (CL_FALSE) upload to mem_salt.  host_salt persists until release_clobj().
 */
static void set_salt(void *salt)
{
	const size_t nbytes = sizeof(salt_t);

	memcpy(host_salt, salt, nbytes);
	HANDLE_CLERROR(clEnqueueWriteBuffer(queue[gpu_id], mem_salt, CL_FALSE,
	                                    0, nbytes, host_salt, 0, NULL, NULL),
	               "Copy salt to gpu");
}

static void opencl_limit_gws(int count)
{
	global_work_size =
	    (count + local_work_size - 1) / local_work_size * local_work_size;
}

/*
 * crypt_all replacement used only while auto-tuning; init() installs it.
 *
 * Performs, in order:
 *   1. one key upload
 *   2. one init-kernel launch
 *   3. one untimed warm-up launch of the loop kernel
 *   4. one timed loop-kernel launch
 *   5. a blocking read-back
 *
 * Steps 1, 2, 4 and 5 record multi_profilingEvent[0..3], matching the `warn`
 * labels (xfer, init, crypt, res xfer).  Returns count.
 */
static int crypt_all_benchmark(int *pcount, struct db_salt *salt)
{
	int count = *pcount;
	size_t gws;
	/* A NULL local size lets the OpenCL implementation choose one. */
	size_t *lws = local_work_size ? &local_work_size : NULL;

	gws = GET_MULTIPLE_OR_BIGGER(count, local_work_size);

	// Copy data to gpu
	BENCH_CLERROR(clEnqueueWriteBuffer(queue[gpu_id], mem_in,
		CL_FALSE, 0, gws * sizeof(pass_t), host_pass, 0,
		NULL, multi_profilingEvent[0]), "Copy data to gpu");

	// Run 1st kernel
	BENCH_CLERROR(clEnqueueNDRangeKernel(queue[gpu_id], crypt_kernel,
		1, NULL, &gws, lws, 0, NULL,
		multi_profilingEvent[1]), "Run kernel");

	// Warm-up run of the loop kernel (not profiled).  Report it as the
	// split kernel so its errors are distinguishable from the init kernel's.
	BENCH_CLERROR(clEnqueueNDRangeKernel(queue[gpu_id], split_kernel,
		1, NULL, &gws, lws, 0, NULL,
		NULL), "Run split kernel");
	BENCH_CLERROR(clEnqueueNDRangeKernel(queue[gpu_id], split_kernel,
		1, NULL, &gws, lws, 0, NULL,
		multi_profilingEvent[2]), "Run split kernel");

	// Read the result back
	BENCH_CLERROR(clEnqueueReadBuffer(queue[gpu_id], mem_out,
		CL_TRUE, 0, gws * sizeof(crack_t), host_crack, 0,
		NULL, multi_profilingEvent[3]), "Copy result back");

	return count;
}

/*
 * Hash every queued candidate against the current salt (host_salt).
 *
 * Steps:
 *   1. Upload host_pass (non-blocking).
 *   2. Run the init kernel once.
 *   3. Launch the loop kernel `loops` times.  After each launch, wait for
 *      the queue with clFinish() and call opencl_process_event().
 *   4. Read the results into host_crack with a blocking read.
 *
 * Returns count unchanged.  `salt` is unused; the salt comes from set_salt().
 *
 * NOTE(review): opencl_process_event() presumably lets JtR service pending
 * events (status/abort) between launches — confirm in common-opencl.
 *
 * NOTE(review): 13 * HASH_LOOPS == 11999 == ITERATIONS - 1 suggests the init
 * kernel performs the first iteration.  If so, ceil((rounds - 1) / HASH_LOOPS)
 * launches would suffice, whereas the formula below gives 14 for
 * rounds == 12000.  Confirm against the kernel before changing it.
 */
static int crypt_all(int *pcount, struct db_salt *salt)
{
	int i, count = *pcount;
	// Number of loop-kernel launches: rounds / HASH_LOOPS, rounded up.
	int loops = (host_salt->rounds + HASH_LOOPS - 1) / HASH_LOOPS;

	// Round the global size up to a multiple of local_work_size.
	opencl_limit_gws(count);

#if 0
	printf("crypt_all(%d)\n", count);
	printf("LWS = %d, GWS = %d\n", (int)local_work_size, (int)global_work_size);
#endif

	// Copy data to gpu
	BENCH_CLERROR(clEnqueueWriteBuffer(queue[gpu_id], mem_in,
		CL_FALSE, 0, global_work_size * sizeof(pass_t), host_pass, 0,
		NULL, NULL), "Copy data to gpu");

	// Run kernel
	HANDLE_CLERROR(clEnqueueNDRangeKernel(queue[gpu_id], crypt_kernel,
		1, NULL, &global_work_size, &local_work_size, 0, NULL,
		profilingEvent), "Run kernel");

	for(i = 0; i < loops; i++) {
		HANDLE_CLERROR(clEnqueueNDRangeKernel(queue[gpu_id], split_kernel,
			1, NULL, &global_work_size, &local_work_size, 0, NULL,
			profilingEvent), "Run split kernel");
		// Block until this launch completes before the next one.
		HANDLE_CLERROR(clFinish(queue[gpu_id]), "clFinish");
		opencl_process_event();
	}
	// Read the result back
	HANDLE_CLERROR(clEnqueueReadBuffer(queue[gpu_id], mem_out,
		CL_TRUE, 0, global_work_size * sizeof(crack_t), host_crack, 0,
		NULL, NULL), "Copy result back");

	return count;
}

/*
 * Cheap pre-filter.  Returns 1 if the first 32-bit word of any of the first
 * `count` computed hashes equals the first word of `binary`, else 0.
 */
static int cmp_all(void *binary, int count)
{
	const uint32_t *want = binary;
	int idx = 0;

	while (idx < count) {
		if (host_crack[idx].hash[0] == want[0])
			return 1;
		idx++;
	}
	return 0;
}

static int cmp_one(void *binary, int index)
{
	int i;
	for (i = 0; i < 8; i++)
		if (host_crack[index].hash[i] != ((uint32_t *) binary)[i])
			return 0;
	return 1;
}

/* cmp_one() already compares the whole 256-bit result, so nothing remains. */
static int cmp_exact(char *source, int index)
{
	(void)source;
	(void)index;
	return 1;
}

/*
 * Store candidate `key` in slot `index` of host_pass.
 *
 * The key is truncated to PLAINTEXT_LENGTH bytes.  The bytes are not
 * NUL-terminated; the length field records how many are valid.
 * crypt_all() uploads the whole array.
 */
static void set_key(char *key, int index)
{
	/* Clamp explicitly: MIN() is a macro and would evaluate strlen() twice. */
	size_t len = strlen(key);

	if (len > PLAINTEXT_LENGTH)
		len = PLAINTEXT_LENGTH;
	memcpy(host_pass[index].v, key, len);
	host_pass[index].length = (uint8_t)len;
}

static char *get_key(int index)
{
	static char ret[PLAINTEXT_LENGTH + 1];
	memcpy(ret, host_pass[index].v, PLAINTEXT_LENGTH);
	ret[host_pass[index].length] = 0;
	return ret;
}

/* Low 4 bits of the first 32-bit word of a decoded binary (see get_hash_0). */
static int binary_hash_0(void *binary)
{
	const uint32_t *words = binary;

	return words[0] & 0xf;
}

/*
 * Partial-hash functions.  get_hash_N returns the low bits of the first
 * 32-bit word of host_crack[index].  The masks are 4, 8, 12, 16, 20, 24 and
 * 27 bits wide.
 *
 * They presumably have to agree with the binary-hash functions listed in
 * fmt_main; binary_hash_0 uses the same 0xf mask.  Keep them in sync.
 */
static uint32_t crack_word0(int index)
{
	return host_crack[index].hash[0];
}

static int get_hash_0(int index) { return crack_word0(index) & 0xf; }
static int get_hash_1(int index) { return crack_word0(index) & 0xff; }
static int get_hash_2(int index) { return crack_word0(index) & 0xfff; }
static int get_hash_3(int index) { return crack_word0(index) & 0xffff; }
static int get_hash_4(int index) { return crack_word0(index) & 0xfffff; }
static int get_hash_5(int index) { return crack_word0(index) & 0xffffff; }
static int get_hash_6(int index) { return crack_word0(index) & 0x7ffffff; }

#if FMT_MAIN_VERSION > 11
/* Tunable-cost reporter: the PBKDF2 iteration count stored in the salt. */
static unsigned int iteration_count(void *salt)
{
	return ((salt_t *)salt)->rounds;
}
#endif

/*
 * Format descriptor registered with John (see john_register_one at the top
 * of the file).  Both sub-structures are positional initializers, so entries
 * must stay in the order the fmt_main definition in formats.h expects.
 */
struct fmt_main fmt_opencl_pbkdf2_hmac_sha256 = {
{
	FORMAT_LABEL,
	FORMAT_NAME,
	ALGORITHM_NAME,
	BENCHMARK_COMMENT,
	BENCHMARK_LENGTH,
	PLAINTEXT_LENGTH,
	BINARY_SIZE,
	BINARY_ALIGN,
	SALT_SIZE,
	SALT_ALIGN,
	/* Presumably min/max keys per crypt; init()'s auto-tune is expected to
	 * set the real values at runtime — confirm in formats.h. */
	1,
	1,
	FMT_CASE | FMT_8_BIT,
#if FMT_MAIN_VERSION > 11
		{
			"iteration count",
		},
#endif
	tests
}, {
	init,
	done,
	fmt_default_reset,
	prepare,
	valid,
	fmt_default_split,
	binary,
	get_salt,
#if FMT_MAIN_VERSION > 11
		{
			iteration_count,
		},
#endif
	fmt_default_source,
	{
		/* Masks must match get_hash_0..6 below. */
		binary_hash_0,
		fmt_default_binary_hash_1,
		fmt_default_binary_hash_2,
		fmt_default_binary_hash_3,
		fmt_default_binary_hash_4,
		fmt_default_binary_hash_5,
		fmt_default_binary_hash_6
	},
	fmt_default_salt_hash,
	set_salt,
	set_key,
	get_key,
	fmt_default_clear_keys,
	/* init() swaps this for crypt_all_benchmark during auto-tuning. */
	crypt_all,
	{
		get_hash_0,
		get_hash_1,
		get_hash_2,
		get_hash_3,
		get_hash_4,
		get_hash_5,
		get_hash_6
	},
	cmp_all,
	cmp_one,
	cmp_exact
}};

#endif /* plugin stanza */

#endif /* HAVE_OPENCL */
