import unittest, datetime
from pyorq import *

class test_create_ptype(unittest.TestCase):
    """Declare a persistent class ``A`` once per supported property type.

    Each test declares ``A`` against the module-level ``db`` (set by
    test_db) and checks the resulting ``tid`` attribute.
    NOTE(review): presumably declaring the class creates a table named
    after it, which is why tearDown drops 'A' -- confirm against pyorq.
    """

    def tearDown(self):
        # Best-effort cleanup: drop_table may fail (for instance when the
        # class declaration itself failed), and that must not fail the test.
        # ``Exception`` instead of a bare ``except`` so that interpreter-level
        # exits such as KeyboardInterrupt are not swallowed.
        try:
            db.drop_table('A')
        except Exception:
            pass

    def _check_tid(self, cls):
        # Shared assertion: ``cls.tid`` must be an int, a long or a str.
        self.failUnless(isinstance(cls.tid, (int, long, str)))

    def test_create_empty_class(self):
        # Building a class through the metaclass with ``database`` as its
        # only attribute is expected to raise DBError.
        self.failUnlessRaises(DBError,
                              ptype.ptype,
                              'A', (pobject,), {'database':db})

    def test_create_class_pint(self):
        class A(pobject):
            database = db
            a = pint()

        self._check_tid(A)

    def test_create_class_pfloat(self):
        class A(pobject):
            database = db
            a = pfloat()

        self._check_tid(A)

    def test_create_class_pstr(self):
        class A(pobject):
            database = db
            a = pstr()

        self._check_tid(A)

    def test_create_class_pdate(self):
        class A(pobject):
            database = db
            a = pdate()

        self._check_tid(A)

    def test_create_class_ptime(self):
        class A(pobject):
            database = db
            a = ptime()

        self._check_tid(A)

    def test_create_class_pdatetime(self):
        class A(pobject):
            database = db
            a = pdatetime()

        self._check_tid(A)

class test_create_schema(unittest.TestCase):
    """Declare two persistent classes where ``B`` holds a reference to ``A``."""

    def tearDown(self):
        # Best-effort cleanup of both tables; a failure on one must not stop
        # the other from being dropped, nor fail the test.  ``Exception``
        # instead of a bare ``except`` so KeyboardInterrupt is not swallowed.
        for table in ['A', 'B']:
            try:
                db.drop_table(table)
            except Exception:
                pass

    def test_create_schema(self):
        class A(pobject):
            database = db
            a = pint()
        class B(pobject):
            database = db
            a = pref(A)
            b = pint()
        # Both declared classes must expose a tid of an accepted type; the
        # referring class is checked too, not only the referenced one.
        for cls in (A, B):
            self.failUnless(isinstance(cls.tid, (int, long, str)))

        
class test_create_instance(unittest.TestCase):
    """Create, commit and re-fetch instances of a single persistent class."""

    def tearDown(self):
        # Best-effort cleanup; errors from drop_table are deliberately
        # ignored.  ``Exception`` instead of a bare ``except`` so that
        # KeyboardInterrupt is not swallowed.
        try:
            db.drop_table('A')
        except Exception:
            pass

    def test_new_instance(self):
        # Instances that were never committed have no oid.
        class A(pobject):
            database = db
            a = pint()
        a = A()
        self.failUnless(isinstance(a, A))
        self.failUnless(a.oid is None)
        b = A()
        self.failUnless(b.oid is None)

    def test_oid_of_commited_instance(self):
        # commit() assigns an oid, and two committed objects get distinct ones.
        class A(pobject):
            database = db
            a = pint()
        a = A()
        a.commit()
        self.failUnless(isinstance(a, A))
        self.failIf(a.oid is None)
        b = A()
        b.commit()
        self.failIf(b.oid is None)
        self.failIfEqual(a.oid, b.oid)

    def test_get_cached__instance(self):
        # Asking for an already-committed oid must hand back the very same
        # Python object, not a copy.
        class A(pobject):
            database = db
            a = pint()
        a = A()
        a.commit()
        b = A(oid=a.oid)
        self.failUnless(a is b)

    def test_update(self):
        # Commit, modify, commit again.  Besides "no exception raised", the
        # object must still carry an oid and the new value (the same post-
        # condition style test_complex_instances.test_update uses).
        class A(pobject):
            database = db
            a = pint()
        a = A()
        a.a = 1
        a.commit()
        a.a = 2
        a.commit()
        self.failIf(a.oid is None)
        self.failUnlessEqual(a.a, 2)
        
class test_complex_instances(unittest.TestCase):
    """Create and commit objects of a class ``B`` that references class ``A``."""

    def tearDown(self):
        # Best-effort cleanup of both tables; errors are deliberately
        # ignored.  ``Exception`` instead of a bare ``except`` so that
        # KeyboardInterrupt is not swallowed.
        for table in ['A', 'B']:
            try:
                db.drop_table(table)
            except Exception:
                pass

    def test_create_instances(self):
        # Attribute access through the reference works before any commit.
        class A(pobject):
            database = db
            a = pint()
        class B(pobject):
            database = db
            a = pref(A)
            b = pint()

        b = B()
        b.a = A()
        b.b = 2
        b.a.a = 1

        self.failUnlessEqual(b.a.a, 1)

    def test_oid_of_commited_instance(self):
        # Committing the referring object must leave both it and the
        # referenced object with an oid.
        class A(pobject):
            database = db
            a = pint()
        class B(pobject):
            database = db
            a = pref(A)
            b = pint()

        b = B()
        b.a = A()
        b.commit()
        self.failIf(b.oid is None)
        self.failIf(b.a.oid is None)

    def test_oid_of_implicit_instance(self):
        # The reference is declared with a second argument ``(A, None)`` and
        # is never assigned; the test expects ``b.a`` to be an A with an oid
        # after commit.
        # NOTE(review): presumably the second argument describes how to
        # build a default referenced object -- confirm against pref().
        class A(pobject):
            database = db
            a = pint()
        class B(pobject):
            database = db
            a = pref(A, (A, None))
            b = pint()

        b = B()
        b.commit()
        self.failUnless(isinstance(b.a, A))
        self.failIf(b.oid is None)
        self.failIf(b.a.oid is None)

    def test_commit_with_ref_to_None(self):
        # An unassigned plain pref stays None across commit.
        class A(pobject):
            database = db
            a = pint()
        class B(pobject):
            database = db
            a = pref(A)
            b = pint()

        b = B()
        b.commit()
        self.failIf(b.oid is None)
        self.failUnless(b.a is None)

    def test_commit_with_all_NULLS(self):
        # Same as above, but B has no attribute other than the reference, so
        # nothing at all has been assigned when commit() is called.
        class A(pobject):
            database = db
            a = pint()
        class B(pobject):
            database = db
            a = pref(A)

        b = B()
        b.commit()
        self.failIf(b.oid is None)
        self.failUnless(b.a is None)

    def test_update(self):
        # Commit, change a plain attribute, commit again; both the object and
        # its implicitly created reference must still have an oid.
        class A(pobject):
            database = db
            a = pint()
        class B(pobject):
            database = db
            a = pref(A, (A, None))
            b = pint()

        b = B()
        b.commit()
        b.b = 2
        b.commit()
        self.failIf(b.oid is None)
        self.failIf(b.a.oid is None)

class test_retrieve(unittest.TestCase):
    """Round-trip one object per property type through the database.

    Every test follows the same sequence: assign two values, commit,
    remember the oid, delete the local reference, clear the database cache,
    re-create the object from its oid and compare the values read back.
    """

    def tearDown(self):
        # Best-effort cleanup; errors from drop_table are deliberately
        # ignored.  ``Exception`` instead of a bare ``except`` so that
        # KeyboardInterrupt is not swallowed.
        try:
            db.drop_table('A')
        except Exception:
            pass

    def test_simple_int(self):
        class A(pobject):
            database = db
            a = pint()
            b = pint()
        a = A()
        a.a, a.b = 3, 5
        a.commit()
        oid = a.oid
        # Drop our reference and the cache so that A(oid=...) cannot simply
        # return the object committed above.
        del a
        A.database.clear_cache()
        a = A(oid=oid)
        self.failUnlessEqual(a.a, 3)
        self.failUnlessEqual(a.b, 5)

    def test_simple_float(self):
        class A(pobject):
            database = db
            a = pfloat()
            b = pfloat()
        a = A()
        # One integral value and one with a fractional part.
        a.a, a.b = 3, 3.3
        a.commit()
        oid = a.oid
        del a
        A.database.clear_cache()
        a = A(oid=oid)
        self.failUnlessEqual(a.a, 3)
        self.failUnlessEqual(a.b, 3.3)

    def test_simple_str(self):
        class A(pobject):
            database = db
            a = pstr()
            b = pstr()
        a = A()
        # The second value contains a single quote.
        a.a, a.b = 'a', "a'b"
        a.commit()
        oid = a.oid
        del a
        A.database.clear_cache()
        a = A(oid=oid)
        self.failUnlessEqual(a.a, 'a')
        self.failUnlessEqual(a.b, "a'b")

    def test_simple_date(self):
        class A(pobject):
            database = db
            a = pdate()
            b = pdate()
        a = A()
        a.a, a.b = datetime.date(1999, 1, 1), datetime.date(1999, 2, 2)
        a.commit()
        oid = a.oid
        del a
        A.database.clear_cache()
        a = A(oid=oid)
        self.failUnlessEqual(a.a, datetime.date(1999, 1, 1))
        self.failUnlessEqual(a.b, datetime.date(1999, 2, 2))

    def test_simple_time(self):
        class A(pobject):
            database = db
            a = ptime()
            b = ptime()
        a = A()
        a.a, a.b = datetime.time(0, 0, 0), datetime.time(1, 1, 1)
        a.commit()
        oid = a.oid
        del a
        A.database.clear_cache()
        a = A(oid=oid)
        self.failUnlessEqual(a.a, datetime.time(0, 0, 0))
        self.failUnlessEqual(a.b, datetime.time(1, 1, 1))

class test_complex_retrieve(unittest.TestCase):
    """Round-trip objects that reference other persistent objects."""

    def tearDown(self):
        # Best-effort cleanup of every table any test here may have used;
        # errors are deliberately ignored.  ``Exception`` instead of a bare
        # ``except`` so that KeyboardInterrupt is not swallowed.
        for table in ['A', 'B', 'C']:
            try:
                db.drop_table(table)
            except Exception:
                pass

    def test_retrieve_ref(self):
        # Only ``b`` is committed explicitly; after a reload the referenced
        # object's value must be reachable through the reference.
        class A(pobject):
            database = db
            a = pint()

        class B(pobject):
            database = db
            b = pref(A)

        a = A()
        a.a = 2
        b = B()
        b.b = a
        b.commit()
        oid = b.oid
        # Drop our references and the cache so B(oid=...) cannot simply
        # return the objects built above.
        del a, b
        A.database.clear_cache()
        b = B(oid=oid)
        self.failUnlessEqual(b.b.a, 2)

    def test_retrieve_NULL_ref(self):
        # A reference that was never assigned reads back as None.
        class A(pobject):
            database = db
            a = pint()

        class B(pobject):
            database = db
            b = pref(A)

        b = B()
        b.commit()
        oid = b.oid
        del b
        A.database.clear_cache()
        b = B(oid=oid)
        self.failUnlessEqual(b.b, None)

    def test_retrieve_derived(self):
        # The reference is declared as pref(A) but holds an instance of the
        # subclass B; after a reload it must come back as a B, including the
        # attribute that only B defines.
        class A(pobject):
            database = db
            a = pint()

        class B(A):
            b = pint()

        class C(pobject):
            database = db
            c = pref(A)

        b = B()
        b.b = 3
        c = C()
        c.c = b
        c.commit()
        oid = c.oid
        del b, c
        A.database.clear_cache()
        c = C(oid=oid)
        self.failUnless(isinstance(c.c, B))
        self.failUnlessEqual(c.c.b, 3)
        

# One suite holding the tests of every TestCase class in this module, in
# declaration order; test_db() runs it once per database interface.
suite = unittest.TestSuite(map(unittest.makeSuite, (
    test_create_ptype,
    test_create_schema,
    test_create_instance,
    test_complex_instances,
    test_retrieve,
    test_complex_retrieve,
)))

def test_db(external_lib_name, interface_name, **kwargs):
    """Run the module-level test suite against one pyorq database interface.

    external_lib_name -- importable name of the module the interface depends
                         on; when it cannot be imported the run is skipped.
    interface_name    -- name of the module under ``pyorq.interface`` and of
                         the class inside it that gets instantiated.
    **kwargs          -- passed unchanged to the interface constructor.

    Rebinds the module-level ``db`` that the test classes read.  Returns
    None in every case; progress and results are printed.
    """
    global db

    # Try to import the external library; only its availability matters.
    try:
        __import__(external_lib_name)
    except ImportError:
        print("%s not available" % external_lib_name)
        return

    try:
        # __import__ on a dotted name returns the top-level package, so the
        # interface module is reached by attribute access from there.
        mod = __import__('pyorq.interface.'+interface_name)
        mod = getattr(mod.interface, interface_name)
        db = getattr(mod, interface_name)(**kwargs)
    except Exception:
        # Any ordinary failure (missing module, connection refused, ...)
        # skips this interface.  ``Exception`` rather than a bare ``except``
        # so that KeyboardInterrupt still aborts the whole run.
        print("Unable to instantiate interface %s" % interface_name)
        return

    print("Running with %s" % interface_name)
    unittest.TextTestRunner(verbosity=1).run(suite)

if __name__ == '__main__':
    # (module that must be importable, pyorq interface name, constructor
    # keyword arguments).  "sys" is always importable, so the nodb interface
    # is always attempted.
    backends = (
        ("sys", "nodb", {}),
        ("pyPgSQL.libpq", "postgresql_db", {"database": "testdb"}),
        ("_sqlite", "sqlite_db", {"database": "testdb"}),
        ("_mysql", "mysql_db", {"db": "testdb"}),
    )
    for driver, interface, options in backends:
        test_db(driver, interface, **options)
