# clarity_tests.py 2.81 KB
import unittest
import sqlalchemy
from unittest.mock import MagicMock

def get_mssql_engine(
    username="lynchse",
    host="clarityprod.uphs.upenn.edu",
    database="clarity_snapshot_db",
    domain="UPHS",
    port="1433",
    timeout=600,
    password=None,
):
    """Create a SQLAlchemy engine for a SQL Server database via pymssql.

    The login name placed in the URL is ``domain``, a backslash, and
    ``username``.

    Args:
        username: Account name without the domain prefix.
        host: Server host name.
        database: Database name placed in the URL path.
        domain: Domain prepended to ``username``.
        port: Server port.
        timeout: Value passed as the ``timeout`` URL query parameter.
        password: Password for the login. When ``None`` the URL is built
            without a password part instead of embedding the text "None".

    Returns:
        The object returned by ``sqlalchemy.create_engine`` for the URL.
    """
    from urllib.parse import quote

    from sqlalchemy import create_engine

    user = domain + "\\" + username
    if password is None:
        credentials = user
    else:
        # Percent-encode the password so characters such as '@', ':', '/'
        # or '%' are not mistaken for URL delimiters or escape sequences.
        escaped_password = quote(str(password), safe="")
        credentials = f"{user}:{escaped_password}"
    url = f"mssql+pymssql://{credentials}@{host}:{port}/{database}?timeout={timeout}"
    return create_engine(url)


# Default credentials file path used by get_clarity_engine() when no path is given.
selahcredsfilename = r'C:\Users\LynchSe\Documents\selah_clarity_credentials.txt'
def get_clarity_engine(credsfilename=selahcredsfilename, timeout=600):
    """Build an engine from a two-line credentials file.

    The first line of ``credsfilename`` is used as the username and the
    second line as the password; surrounding whitespace is stripped from
    both. ``timeout`` is forwarded unchanged to ``get_mssql_engine``.

    Returns:
        The engine returned by ``get_mssql_engine``.
    """
    with open(credsfilename, 'r') as creds:
        first_line = creds.readline()
        second_line = creds.readline()
    return get_mssql_engine(
        username=first_line.strip(),
        password=second_line.strip(),
        timeout=timeout,
    )


##### BEGIN ACTUAL TESTS #####

class TestStuff(unittest.TestCase):
    """Tests that run queries through the engine from ``get_clarity_engine()``.

    Each test builds its own engine and queries the ``PAT_ENC`` table, so
    they need a reachable database and a readable credentials file.
    """

    def test_close_conn(self):
        """A query on a context-managed connection returns three rows."""
        eng = get_clarity_engine()
        with eng.connect() as conn:
            res = conn.execute('SELECT TOP 3 PAT_ID FROM PAT_ENC')
            self.assertEqual(len(list(res)), 3)

    # Keep the original method name available so suites that were built
    # with TestStuff("close_conn") continue to work; the test_ prefixed
    # name lets standard test discovery pick the test up as well.
    close_conn = test_close_conn

    def test_temp_table_after_reconnect(self):
        """##COHORT built on one connection is not selectable on the next."""
        eng = get_clarity_engine()
        with eng.connect() as conn:
            conn.execute('DROP TABLE IF EXISTS ##COHORT')
            conn.execute('SELECT TOP 3 PAT_ID INTO ##COHORT FROM PAT_ENC')
            res = conn.execute('SELECT * FROM ##COHORT')
            self.assertEqual(len(list(res)), 3)
        with eng.connect() as conn:
            # Expect a database-level failure specifically, so an unrelated
            # Python error (typo, bad argument) cannot satisfy the test.
            with self.assertRaises(sqlalchemy.exc.DBAPIError):
                conn.execute('SELECT * FROM ##COHORT')

    def test_temp_table_both_handles(self):
        """##COHORT can be built via a Connection and via a raw DBAPI cursor."""
        eng = get_clarity_engine()
        with eng.connect() as conn:
            conn.execute('DROP TABLE IF EXISTS ##COHORT')
            conn.execute('SELECT TOP 3 PAT_ID INTO ##COHORT FROM PAT_ENC')
            res = conn.execute('SELECT * FROM ##COHORT')
            self.assertEqual(len(list(res)), 3)

        raw_conn = eng.raw_connection()
        try:
            with raw_conn.cursor() as cur:
                cur.execute('DROP TABLE IF EXISTS ##COHORT')
                cur.execute('SELECT TOP 3 PAT_ID INTO ##COHORT FROM PAT_ENC')
                cur.execute('SELECT * FROM ##COHORT')
                self.assertEqual(len([row for row in cur]), 3)
        finally:
            # Hand the raw connection back even when an assertion fails;
            # the cursor's context manager does not close the connection.
            raw_conn.close()


if __name__ == '__main__':
    # Run an explicit, ordered selection of TestStuff cases with the text runner.
    selected_tests = (
        "close_conn",
        "test_temp_table_after_reconnect",
        "test_temp_table_both_handles",
    )
    suite = unittest.TestSuite()
    suite.addTests(TestStuff(name) for name in selected_tests)
    unittest.TextTestRunner().run(suite)


#    unittest.main()