# sql_to_schema.py (2.16 KB)
import pandas as pd
import numpy as np
import sqlparse


def filter_whitespace(tokenlist):
    """Return the tokens of *tokenlist* that are not whitespace or newlines.

    Only the tokens passed in are inspected; the children of grouped tokens
    are left untouched.  The relative order of the kept tokens is preserved.

    Args:
        tokenlist: Iterable of sqlparse tokens (objects with a ``ttype``
            attribute).

    Returns:
        A new list holding every token whose ``ttype`` is neither
        ``sqlparse.tokens.Whitespace`` nor
        ``sqlparse.tokens.Whitespace.Newline``.
    """
    # Newline is its own token type and does not compare equal to Whitespace,
    # so it has to be listed explicitly; otherwise line breaks slip through
    # and a multi-line query is no longer a clean keyword/identifier sequence.
    whitespace_types = (
        sqlparse.tokens.Whitespace,
        sqlparse.tokens.Whitespace.Newline,
    )
    return [token for token in tokenlist if token.ttype not in whitespace_types]


def _normalize_keyword(text):
    """Upper-case *text* and collapse whitespace runs into single spaces."""
    return " ".join(text.upper().split())


def _table_entry(text):
    """Build the ``{'name': ..., 'alias': ...}`` dict for one table reference."""
    parts = text.split()
    # "orders AS o" and "orders o" declare the same alias.
    if len(parts) == 3 and parts[1].upper() == "AS":
        del parts[1]
    entry = {'name': parts[0]}
    if len(parts) == 2:
        entry['alias'] = parts[1]
    return entry


def extract_tables(sql, join_keywords=("INNER JOIN", "LEFT OUTER JOIN")):
    """Extract the tables referenced by a simple ``SELECT`` query.

    Only the first statement parsed from *sql* is examined.  After top-level
    whitespace tokens are removed with ``filter_whitespace``, the statement
    must have exactly this shape::

        SELECT <columns> FROM <table> [<join> <table> ON <comparison>]...

    Keywords are matched case-insensitively.

    Args:
        sql: SQL text to analyse.
        join_keywords: Join keywords accepted between two tables.  Defaults
            to ``INNER JOIN`` and ``LEFT OUTER JOIN``.

    Returns:
        A list with one dict per table, in order of appearance.  Every dict
        has a ``'name'`` key; an ``'alias'`` key is added when the table
        reference is written as ``name alias`` or ``name AS alias``.

    Raises:
        AssertionError: A token does not match the expected query shape.
        IndexError: *sql* contains no statement, or the statement ends
            before the expected shape is complete.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        raise IndexError("no SQL statement found in input")
    tokens = filter_whitespace(statements[0].tokens)
    allowed_joins = {_normalize_keyword(keyword) for keyword in join_keywords}

    def token_at(position):
        if position >= len(tokens):
            raise IndexError(
                "SQL statement ended unexpectedly; expected a token at "
                "position {}".format(position)
            )
        return tokens[position]

    # Mismatches raise AssertionError explicitly instead of using ``assert``:
    # callers still see the same exception type, but the checks are no longer
    # stripped when Python runs with -O.
    def expect_keyword(position, allowed):
        token = token_at(position)
        if _normalize_keyword(token.value) not in allowed:
            raise AssertionError(
                "expected {}, found {!r}".format(
                    " or ".join(sorted(allowed)), token.value
                )
            )
        return token

    def expect_type(position, types, description):
        token = token_at(position)
        if not isinstance(token, types):
            raise AssertionError(
                "expected {}, found {!r}".format(description, token.value)
            )
        return token

    # SELECT <columns> FROM <table>
    expect_keyword(0, {"SELECT"})
    expect_type(
        1,
        (sqlparse.sql.IdentifierList, sqlparse.sql.Identifier),
        "a column list",
    )
    expect_keyword(2, {"FROM"})
    tables = [expect_type(3, sqlparse.sql.Identifier, "a table name")]

    # Zero or more "<join> <table> ON <comparison>" groups.  Each group is
    # followed either by the end of the statement or by another join keyword.
    position = 4
    while position < len(tokens):
        expect_keyword(position, allowed_joins)
        tables.append(
            expect_type(position + 1, sqlparse.sql.Identifier, "a table name")
        )
        expect_keyword(position + 2, {"ON"})
        expect_type(position + 3, sqlparse.sql.Comparison, "a join condition")
        position += 4

    return [_table_entry(table.value) for table in tables]
    


#%% for experimenting

# Placeholder entry point: running this file as a script currently does nothing.
if __name__ == '__main__':
    pass