import pandas as pd
import scipy.stats as stats
from itertools import combinations
import matplotlib.pyplot as plt
import seaborn as sns

# Support of every individual item (frequent or not) seen by ``apriori``,
# keyed by the bare item rather than by a frozenset.  ``apriori`` fills it in
# as a side effect; entries are overwritten per item on later calls but are
# never removed.
L_support = {}


def apriori(transactions, min_support):
    """Mine frequent itemsets from *transactions* level by level.

    The support of an itemset is the fraction of transactions that contain
    every one of its items.  An item repeated inside one transaction counts
    once for that transaction, so single-item supports use the same
    definition as the larger itemsets and can never exceed 1.

    Args:
        transactions: Sized collection (``len`` is taken) of iterables of
            hashable items.  Each transaction is iterated more than once.
        min_support: Inclusive lower bound on support for an itemset to be
            reported.

    Returns:
        A list of ``(itemset, support)`` tuples where ``itemset`` is a plain
        ``set``.  Entries are ordered by itemset size and, within a size, in
        the order the itemsets were generated.  Empty input yields ``[]``.

    Side effects:
        Stores the support of every single item, frequent or not, in the
        module-level ``L_support`` dict.
    """
    n_transactions = len(transactions)
    # Convert each transaction to a set once instead of once per level.
    transaction_sets = [frozenset(t) for t in transactions]

    # ``dict.fromkeys`` removes duplicates within a transaction while keeping
    # first-seen order, which keeps the output order deterministic.
    item_counts = {}
    for transaction in transactions:
        for item in dict.fromkeys(transaction):
            item_counts[item] = item_counts.get(item, 0) + 1

    support_data = {}
    frequent_singletons = []
    for item, count in item_counts.items():
        support = count / n_transactions
        L_support[item] = support
        if support >= min_support:
            singleton = frozenset([item])
            frequent_singletons.append(singleton)
            support_data[singleton] = support
    levels = [frequent_singletons]

    k = 2
    while True:
        previous_level = levels[-1]
        # A dict doubles as an insertion-ordered set: O(1) duplicate checks
        # while keeping the order in which candidates were generated.
        candidate_counts = {}
        for left, right in combinations(previous_level, 2):
            candidate = left | right
            if len(candidate) == k:
                candidate_counts.setdefault(candidate, 0)

        for tset in transaction_sets:
            for candidate in candidate_counts:
                if candidate.issubset(tset):
                    candidate_counts[candidate] += 1

        current_level = []
        for candidate, count in candidate_counts.items():
            support = count / n_transactions
            if support >= min_support:
                current_level.append(candidate)
                support_data[candidate] = support

        # Stop as soon as a level produces no frequent itemset.
        if not current_level:
            break
        levels.append(current_level)
        k += 1

    return [
        (set(itemset), support_data[itemset])
        for level in levels
        for itemset in level
    ]

def calculate_imbalance_ratio(frequent_patterns):
    """Return the ratio of the largest to the smallest support.

    Args:
        frequent_patterns: Iterable of sequences whose element at index 1 is
            the support value, e.g. ``(itemset, support)`` tuples.

    Returns:
        ``max_support / min_support``.  When the smallest support is zero the
        ratio is unbounded and ``float("inf")`` is returned instead of
        raising ``ZeroDivisionError``.

    Raises:
        ValueError: If *frequent_patterns* is empty.
    """
    supports = [pattern[1] for pattern in frequent_patterns]
    if not supports:
        raise ValueError(
            "cannot compute an imbalance ratio without any frequent patterns"
        )
    max_support = max(supports)
    min_support = min(supports)
    if min_support == 0:
        return float("inf")
    return max_support / min_support

def chi_square_test(frequent_patterns, transactions):
    """Run a chi-square independence test for every 2-item pattern.

    For each pair the transactions are split into the four cells of a 2x2
    table (both items, first only, second only, neither).  The observed
    counts are compared with the counts expected if the two items occurred
    independently at their observed marginal frequencies.

    Args:
        frequent_patterns: Iterable of ``(itemset, support)`` tuples; only
            itemsets of length 2 are tested, the support value is ignored.
        transactions: Sized collection of iterables of hashable items.

    Returns:
        A list of ``(itemset, chi2, p_value)`` tuples, one per 2-item
        pattern, in input order.  ``chi2`` and ``p_value`` are NaN when an
        expected cell count is zero (an item occurs in every transaction),
        because the statistic is undefined there.
    """
    n = len(transactions)
    trans_sets = [frozenset(t) for t in transactions]
    chi_squared_results = []
    for itemset, _ in frequent_patterns:
        if len(itemset) != 2:
            continue
        item1, item2 = itemset

        # Fill the 2x2 contingency table in a single pass.
        count_both = count_a_only = count_b_only = 0
        for t in trans_sets:
            has_a = item1 in t
            has_b = item2 in t
            if has_a and has_b:
                count_both += 1
            elif has_a:
                count_a_only += 1
            elif has_b:
                count_b_only += 1
        count_neither = n - count_both - count_a_only - count_b_only
        observed = [count_both, count_a_only, count_b_only, count_neither]

        support_a = (count_both + count_a_only) / n
        support_b = (count_both + count_b_only) / n
        expected = [
            n * support_a * support_b,
            n * support_a * (1 - support_b),
            n * (1 - support_a) * support_b,
            n * (1 - support_a) * (1 - support_b),
        ]

        if min(expected) == 0:
            # A zero expected cell would make the statistic 0/0.
            chi2 = p_value = float("nan")
        else:
            # Both marginals are estimated from the data, so the test has
            # 4 - 1 - 2 = 1 degree of freedom.  ``chisquare`` assumes k - 1
            # unless told otherwise, hence ddof=2.
            chi2, p_value = stats.chisquare(
                f_obs=observed, f_exp=expected, ddof=2
            )
        chi_squared_results.append((itemset, chi2, p_value))
    return chi_squared_results

def calculate_lift(frequent_patterns, transactions):
    """Compute the lift of every 2-item pattern.

    Lift is ``support(A and B) / (support(A) * support(B))``.  All supports
    are derived from *transactions* itself, so the function does not depend
    on any earlier mining run or on module-level state.

    Args:
        frequent_patterns: Iterable of ``(itemset, support)`` tuples where
            ``itemset`` is a ``set`` or ``frozenset``; only itemsets of
            length 2 are processed, the support value is ignored.
        transactions: Sized collection of iterables of hashable items.

    Returns:
        A list of ``(itemset, lift)`` tuples in input order.  The lift is 0
        when either item never occurs in *transactions*.
    """
    n_transactions = len(transactions)
    # Convert once; the original rebuilt every frozenset for every pair.
    transaction_sets = [frozenset(t) for t in transactions]

    # Number of transactions containing each individual item.
    item_counts = {}
    for tset in transaction_sets:
        for item in tset:
            item_counts[item] = item_counts.get(item, 0) + 1

    lift_results = []
    for itemset, _ in frequent_patterns:
        if len(itemset) != 2:
            continue
        item1, item2 = itemset
        pair_support = (
            sum(1 for tset in transaction_sets if itemset.issubset(tset))
            / n_transactions
        )
        support_a = item_counts.get(item1, 0) / n_transactions
        support_b = item_counts.get(item2, 0) / n_transactions
        if support_a > 0 and support_b > 0:
            itemset_lift = pair_support / (support_a * support_b)
        else:
            itemset_lift = 0
        lift_results.append((itemset, itemset_lift))
    return lift_results

def main():
    """Run the whole analysis on the student-habits CSV and print the reports.

    Steps, in order:
      1. Load the CSV and label every numeric column as ``<col>_high`` or
         ``<col>_low`` relative to that column's median.
      2. Turn each row into a transaction of those labels plus
         ``<col>_<value>`` items for the categorical columns.
      3. Mine frequent itemsets with a minimum support of 0.25.
      4. Print the imbalance ratio, the chi-square table (also shown as a
         heatmap), the lift table and the frequent itemsets.
    """
    df = pd.read_csv("./data_file/student_habits_performance.csv")
    id_col = 'student_id'  # not referenced anywhere below
    numeric_cols = [
        'age', 'study_hours_per_day', 'social_media_hours', 'netflix_hours',
        'attendance_percentage', 'sleep_hours', 'exercise_frequency',
        'mental_health_rating', 'exam_score',
    ]
    categorical_cols = [
        'gender', 'part_time_job', 'diet_quality',
        'parental_education_level', 'internet_quality',
        'extracurricular_participation',
    ]

    # Binarise each numeric column around its own median.
    for column in numeric_cols:
        threshold = df[column].median()
        high_label = f"{column}_high"
        low_label = f"{column}_low"
        df[column + '_bin'] = df[column].apply(
            lambda value: high_label if value > threshold else low_label
        )

    # One transaction per row: binned numeric labels, then categorical items.
    bin_columns = [column + '_bin' for column in numeric_cols]
    transactions = [
        [record[name] for name in bin_columns]
        + [f"{column}_{record[column]}" for column in categorical_cols]
        for _, record in df.iterrows()
    ]

    min_support = 0.25
    frequent_patterns = apriori(transactions, min_support)

    imbalance_ratio = calculate_imbalance_ratio(frequent_patterns)
    print(f"\n{'='*10} Imbalance Ratio {'='*10}")
    print(f"Imbalance Ratio: {imbalance_ratio:.4f}")

    # Flatten each tested pair into its own columns for tabular output.
    pair_rows = []
    for pair, statistic, p_val in chi_square_test(frequent_patterns, transactions):
        if len(pair) != 2:
            continue
        first, second = list(pair)
        pair_rows.append((first, second, statistic, p_val))

    df1 = pd.DataFrame(pair_rows, columns=['item1', 'item2', 'chi_square', 'p_value'])
    df1 = df1.sort_values(by='chi_square', ascending=False).reset_index(drop=True)

    print(f"\n{'='*10} Chi-Square Test Results {'='*10}")
    chi_formatters = {
        'chi_square': '{:.4f}'.format,
        'p_value': '{:.4e}'.format,
    }
    print(df1.to_string(index=False, formatters=chi_formatters))

    chi_square_matrix = df1.pivot_table(values='chi_square', index='item1', columns='item2')
    plt.figure(figsize=(10, 8))
    sns.heatmap(chi_square_matrix, annot=True, cmap='coolwarm', fmt=".2f", cbar=True)
    plt.title('Chi-Square Test Results Heatmap (Itemset Relationships)')
    plt.show()

    df_lift = pd.DataFrame(
        calculate_lift(frequent_patterns, transactions),
        columns=['itemset', 'lift'],
    )
    df_lift = df_lift.sort_values(by='lift', ascending=False).reset_index(drop=True)

    print(f"\n{'='*10} Lift Results {'='*10}")
    print(df_lift.to_string(index=False, formatters={'lift': '{:.4f}'.format}))

    result_df = pd.DataFrame(frequent_patterns, columns=['itemset', 'support'])
    result_df = result_df.sort_values(by='support', ascending=False).reset_index(drop=True)

    print(f"\n{'='*10} Frequent Itemsets {'='*10}")
    print(result_df.to_string(index=False, formatters={'support': '{:.4f}'.format}))
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
