import click
import numpy as np
import pandas as pd
from sklearn.impute import KNNImputer

from prebio_df import read_prebio
from nan_analysis import missing_values_table


def impute_missing(target: pd.Series, dataset: pd.DataFrame) -> pd.DataFrame:
    """Impute the missing values of a target row using KNN imputation.

    Args:
        target (pd.Series): Target row with missing values to be imputed.
        dataset (pd.DataFrame): NaN-free rows used as neighbours for the imputation.

    Returns:
        pd.DataFrame: ``dataset`` with ``target`` appended and its missing values imputed.
    """
    # Append the target row under its original index label so the imputed
    # values can later be aligned back onto the source frame with update().
    dataset = dataset.copy()
    dataset.loc[target.name] = target
    imputer = KNNImputer()
    res = imputer.fit_transform(dataset)

    return pd.DataFrame(res, columns=dataset.columns, index=dataset.index)
    
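# Illustrative usage of impute_missing (hypothetical toy data, not part of the
# pipeline): the target row is appended to the reference frame and its NaN is
# filled from its nearest neighbours, e.g.
#   base = pd.DataFrame({'a': [1.0, 2.0, 3.0], 'b': [10.0, 20.0, 30.0]})
#   row = pd.Series({'a': 2.5, 'b': np.nan}, name=3)
#   impute_missing(row, base)  # row 3 gets 'b' filled in by KNNImputer
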

def check_complete_samples(df: pd.DataFrame) -> tuple:
    """Find all samples (rows) with no missing values.

    Args:
        df (pd.DataFrame): Input dataframe.

    Returns:
        tuple: The indexes of samples without missing values (empty if none),
            an array with the NaN count of each row, and the unique NaN counts
            found in the dataset.
    """
    nans = df.isnull().sum(axis=1).to_numpy()
    indexes = np.where(nans == 0)[0]
    print(f'> The number of instances with 0% of missing values is {len(indexes)}!')
    unique, counts = np.unique(nans, return_counts=True)
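    # The printed dict maps "NaNs per row" to "number of rows"; for example
    # (made-up numbers) {0: 120, 1: 15, 3: 2} would mean 120 complete rows,
    # 15 rows missing one value and 2 rows missing three values.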
    print(dict(zip(unique, counts)))
    
    return indexes, nans, unique


@click.command()
@click.option('--ds', default='../data/polycarbonates_entregable.xlsx', help='Path to prebio2 dataset (.xlsx)')
def main(ds):
    """Main function to preprocess the dataset by imputing missing values and storing the result.
    
    Args:
        ds (str): Path to the prebio2 dataset (.xlsx).
    """
    df = read_prebio(ds)
    missing_table = missing_values_table(df)
    # print(missing_table)
    
    # We drop columns with > 70% of data missing.
    more_than_70_labels = missing_table[missing_table['% of Total Values'] > 70].index
    df70 = df.drop(columns=more_than_70_labels)
    missing_values_table(df70, title="------ Columns dropped ------")
    
    # Get all the samples with no missing values; they form the reference set for the KNN imputation
    indexes, nans, unique = check_complete_samples(df70)
    full_df = df70.iloc[indexes]
    
    # To simplify the imputation we drop the categorical columns
    # (none of them contains NaNs and KNNImputer only handles numeric features)
    full_df = full_df.drop(columns=['Epoxide', 'SMILES', 'Catalyst'])

    # Assert that the complete-sample dataset really has no missing values
    assert missing_values_table(full_df).empty, "This shouldn't be happening. Your full_df has missing values!"

    # Impute missing data sequentially, starting with the rows that have the
    # fewest NaNs so they can serve as neighbours for the rows with more NaNs
    for min_nan in np.sort(unique)[1:]:  # Skip the first group (0 NaNs): those rows are already in full_df
        for index in np.where(nans == min_nan)[0]:
            full_df = impute_missing(df70.iloc[index], full_df)

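    # full_df now holds the original complete rows plus every imputed row,
    # keyed by their df70 index labels, so update() writes the imputed values
    # back into the matching rows of df70.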
    df70.update(full_df)
    missing_values_table(df70, title='------ Imputed DataFrame ------')

    path = '../data/imputed_polycarbonates.xlsx'
    print(f'The new dataset will be stored in {path}')
    df70.to_excel(path)


if __name__ == "__main__":
    main()