Commit ba1be159 by Selah Clarity

work on bulk insert to replace sloooow pandas.to_sql runs

1 parent a6e71d0e
#GOAL insert a csv into a temp table in Clarity
#This library takes a dataframe and imports it into Clarity
import pandas as pd
import numpy as np
def create_table_sql(table_name, column_def):
    """Build the SQL script that drops and re-creates a table.

    Args:
        table_name: Name of the table to drop (if present) and create.
        column_def: Column definition text placed between the parentheses
            of the CREATE TABLE statement.

    Returns:
        A string holding a ``DROP TABLE IF EXISTS`` statement followed by a
        ``CREATE TABLE`` statement, with a leading and a trailing newline.
    """
    # The empty first and last entries reproduce the newline that opens and
    # closes the script.
    script_template = "\n".join([
        "",
        "DROP TABLE IF EXISTS {table_name};",
        "CREATE TABLE {table_name} ({table_column_def});",
        "",
    ])
    return script_template.format(table_name=table_name, table_column_def=column_def)
def _sql_literal(item, fmt):
    """Render a single value as SQL literal text for the given column type code."""
    # None and float NaN (pandas' usual missing-value marker) have no valid
    # literal form other than NULL; formatting them would emit 'None' / nan.
    if item is None or (isinstance(item, float) and item != item):
        return "NULL"
    if fmt in ('STR', 'DT'):
        # Double embedded single quotes so a value such as O'Brien cannot
        # terminate the literal early (broken statement / SQL injection).
        return "'{}'".format(str(item).replace("'", "''"))
    return "{}".format(item)


def format_data_for_insert(rows, column_types):
    """Format rows of values as the row list of an INSERT ... VALUES statement.

    Args:
        rows: Iterable of row sequences (e.g. a list of lists or a 2-D numpy
            array). Do not pass a DataFrame: iterating a DataFrame yields
            its column labels, not its rows.
        column_types: One type code per column. 'STR' and 'DT' values are
            wrapped in single quotes; any other code is written unquoted.

    Returns:
        A string with one parenthesised, comma-separated tuple per row, the
        rows separated by ",\\n". Returns "" when ``rows`` is empty.
    """
    return ",\n".join(
        "({})".format(
            ",".join(_sql_literal(item, fmt) for item, fmt in zip(row, column_types))
        )
        for row in rows
    )
def generate_insert_sql(table_name, column_names, column_types, data, max_insert=1000):
    """Yield INSERT statements that load ``data`` in chunks.

    Args:
        table_name: Target table name.
        column_names: Column names, in the same order as the values in
            each row of ``data``.
        column_types: One type code per column, passed through to
            ``format_data_for_insert``.
        data: A pandas DataFrame (its ``.values`` are used) or a sequence
            of rows.
        max_insert: Upper bound on the number of rows per statement. The
            default of 1000 presumably mirrors SQL Server's limit on rows
            in a single VALUES list -- TODO confirm.

    Yields:
        One ``INSERT INTO ... values ...`` statement per chunk, without a
        trailing semicolon. Nothing is yielded when ``data`` is empty.

    Raises:
        ValueError: If ``max_insert`` is smaller than 1.
    """
    import math  # kept local: the file's import header does not import math

    if max_insert < 1:
        raise ValueError("max_insert must be at least 1")
    if isinstance(data, pd.DataFrame):
        data = data.values
    # len() rather than truthiness: a numpy array has no unambiguous bool().
    if len(data) == 0:
        # Zero rows would request zero sections, which np.array_split rejects.
        return

    num_splits = math.ceil(len(data) / max_insert)
    column_names_str = ','.join(column_names)
    insert_chunk_sql_template = "INSERT INTO {table} ({column_names_str})\nvalues\n{rows_of_data}\n"
    # array_split balances the chunk sizes; none exceeds max_insert.
    for insert_chunk in np.array_split(data, num_splits):
        yield insert_chunk_sql_template.format(
            table=table_name,
            column_names_str=column_names_str,
            rows_of_data=format_data_for_insert(insert_chunk, column_types),
        )
def collect_insert_sql(table_name, column_names, column_types, data, max_insert=1000):
    """Return all chunked INSERT statements for ``data`` as one script.

    Takes the same arguments as ``generate_insert_sql`` and concatenates
    every statement it yields, each followed by a semicolon and a newline.
    The result is an empty string when no statement is generated.
    """
    statements = generate_insert_sql(table_name, column_names, column_types, data, max_insert)
    return "".join(statement + ";\n" for statement in statements)
import pandas as pd
import numpy as np
import bulk_insert
#%%
# Scratch script, split into #%% cells, that exercises the bulk_insert helpers
# on a small hand-written data set and then (last cells) against the database.
# The ## prefix presumably makes this a global temp table -- confirm.
table_name = "##COHORT_BULK_INSERT_TEST"
# Five sample rows: one unquoted number, one string, one date string.
test_data = [
[578,'29389','2011-09-03'],
[332,'11384','2011-09-07'],
[372,'14487','2011-09-07'],
[331,'41384','2011-09-07'],
[931,'24587','2011-10-03']
]
df_test_data = pd.DataFrame(test_data)
#%%
table_column_def = '''
PAT_ID VARCHAR(18) NOT NULL,
MRN VARCHAR(30) NOT NULL,
DELIVERY_DATE DATETIME NOT NULL'''
# NOTE(review): column_names lists MRN before PAT_ID while table_column_def
# declares PAT_ID first, so the first (numeric) value of each row is inserted
# into MRN -- confirm this ordering is intended.
column_names = ["MRN","PAT_ID","DELIVERY_DATE"]
# One type code per entry of column_names.
column_types = ['NUM','STR','DT']
#%%
import bulk_insert
# Show the DROP/CREATE script for the test table.
print(bulk_insert.create_table_sql(table_name, table_column_def))
#%%
# 5 rows with max_insert=3 are expected to produce two chunks; each next()
# call below prints one of them.
insert_sql_generator = bulk_insert.generate_insert_sql(table_name, column_names, column_types, df_test_data, max_insert=3)
print(next(insert_sql_generator))
#%%
print(next(insert_sql_generator))
#%%
#Dataframes should never be passed to this function
#print(bulk_insert.format_data_for_insert(df_test_data, ['NUM','STR','DT']))
print(bulk_insert.format_data_for_insert(test_data, ['NUM','STR','DT']))
#%%
# Print all chunks joined into a single script.
print(bulk_insert.collect_insert_sql(table_name, column_names, column_types, df_test_data, max_insert=3))
#%% Test with clarity
import clarity_to_csv as ctc
#conn = ctc.get_clarity_engine().connect()
# NOTE(review): this repeats the same template text as
# bulk_insert.create_table_sql instead of calling it.
create_create_table_sql = """
DROP TABLE IF EXISTS {table_name};
CREATE TABLE {table_name} ({table_column_def});
""".format(table_name=table_name, table_column_def=table_column_def)
import bulk_insert
# Fresh generator: the one created earlier has already been consumed.
insert_sql_generator = bulk_insert.generate_insert_sql(table_name, column_names, column_types, df_test_data, max_insert=3)
insert_sql1 = next(insert_sql_generator)
insert_sql2 = next(insert_sql_generator)
#%%
# NOTE(review): conn is only assigned in the commented-out line above, so
# these calls raise NameError unless conn already exists in the interactive
# session -- confirm before running this cell.
conn.execute(create_create_table_sql)
conn.execute(insert_sql1)
conn.execute(insert_sql2)
...@@ -27,59 +27,44 @@ def get_clarity_engine(credsfilename = selahcredsfilename, timeout=600): ...@@ -27,59 +27,44 @@ def get_clarity_engine(credsfilename = selahcredsfilename, timeout=600):
##### BEGIN ACTUAL TESTS ##### ##### BEGIN ACTUAL TESTS #####
#because we dont' want to hit clarity more than necessary, we run tests one at a time
class TestStuff(unittest.TestCase): class TestStuff(unittest.TestCase):
def close_conn(self): #Test a basic connect and execute
def test_basic_conn_execute(self):
eng = get_clarity_engine() eng = get_clarity_engine()
with eng.connect() as conn: with eng.connect() as conn:
res = conn.execute('SELECT TOP 3 PAT_ID FROM PAT_ENC') res = conn.execute('SELECT TOP 3 PAT_ID FROM PAT_ENC')
self.assertEqual(len(list(res)), 3) self.assertEqual(len(list(res)), 3)
# conn.close()
def test_temp_table_after_reconnect(self): def test_temp_table_persistence(self):
eng = get_clarity_engine() eng = get_clarity_engine()
with eng.connect() as conn: with eng.connect() as conn:
conn.execute('DROP TABLE IF EXISTS ##COHORT') conn.execute('DROP TABLE IF EXISTS ##COHORT')
conn.execute('SELECT TOP 3 PAT_ID INTO ##COHORT FROM PAT_ENC') conn.execute('SELECT TOP 3 PAT_ID INTO ##COHORT FROM PAT_ENC')
res = conn.execute('SELECT * FROM ##COHORT') res = conn.execute('SELECT * FROM ##COHORT')
self.assertEqual(len(list(res)), 3) self.assertEqual(len(list(res)), 3)
#we expect the global temp table to disappear with new connection
with eng.connect() as conn: with eng.connect() as conn:
with self.assertRaises(Exception) as e: with self.assertRaises(Exception) as e:
res = conn.execute('SELECT * FROM ##COHORT') res = conn.execute('SELECT * FROM ##COHORT')
# print(e.exception) print(e.exception)
def test_temp_table_both_handles(self): # def test_raw_connection(self):
eng = get_clarity_engine() # eng = get_clarity_engine()
with eng.connect() as conn: # with eng.raw_connection().cursor() as cur:
conn.execute('DROP TABLE IF EXISTS ##COHORT') # cur.execute('DROP TABLE IF EXISTS ##COHORT')
conn.execute('SELECT TOP 3 PAT_ID INTO ##COHORT FROM PAT_ENC') # cur.execute('SELECT TOP 3 PAT_ID INTO ##COHORT FROM PAT_ENC')
res = conn.execute('SELECT * FROM ##COHORT') # cur.execute('SELECT * FROM ##COHORT')
self.assertEqual(len(list(res)), 3) # self.assertEqual(len([row for row in cur]), 3)
with eng.raw_connection().cursor() as cur:
cur.execute('DROP TABLE IF EXISTS ##COHORT')
cur.execute('SELECT TOP 3 PAT_ID INTO ##COHORT FROM PAT_ENC')
cur.execute('SELECT * FROM ##COHORT')
self.assertEqual(len([row for row in cur]), 3)
if __name__ == '__main__': if __name__ == '__main__':
tests_to_run = [
"close_conn" unittest.main()
,"test_temp_table_after_reconnect"
,"test_temp_table_both_handles"
]
suite = unittest.TestSuite()
for test in tests_to_run:
suite.addTest(TestStuff(test))
runner = unittest.TextTestRunner()
runner.run(suite)
# unittest.main()
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!