# -*- coding: utf-8 -*-
import os

from tensorflow_core.python.keras.datasets import cifar10
from tensorflow_core.python.keras.engine.input_layer import Input
from tensorflow_core.python.keras.layers.convolutional import Conv2D
from tensorflow_core.python.keras.layers.core import Dense, Dropout, Activation, Flatten
from tensorflow_core.python.keras.layers.normalization import BatchNormalization
from tensorflow_core.python.keras.layers.pooling import MaxPooling2D
from tensorflow_core.python.keras.models import Sequential, Model, load_model
from tensorflow_core.python.keras.optimizer_v2.adam import Adam
import tensorflow as tf

from train_and_evaluation import evaluate_regression_model, train_model
from datasets import load_house_dataset_data, DatasetType
import matplotlib.pyplot as plt

__author__ = '106360'  # module metadata; conventionally a string

def generate_simple_cnn_regression_model(input_shape, n_blocks=3, weights='', is_regression=True, num_classes=1,
                                         freeze=False, remove_head=False):
    """Build a small CNN with an optional regression or classification head.

    The network is ``n_blocks`` x (Conv2D(32, 3x3, same) -> ReLU -> MaxPooling2D),
    followed by Flatten and a 16-unit ReLU Dense layer named ``pre_last_dense_reg``.

    :param input_shape: side length of the square RGB input; the model input
        is ``(input_shape, input_shape, 3)``.
    :param n_blocks: number of conv/ReLU/pool blocks. Conv layers are named
        ``conv_0`` ... ``conv_{n_blocks-1}`` so weights can be matched by name.
    :param weights: path of a weights file loaded with ``by_name=True``, so
        only layers whose names match are restored. An empty string (or
        ``None``) skips loading.
        NOTE: the Flatten size, and hence the ``pre_last_dense_reg`` kernel
        shape, depends on ``input_shape``. Loading that layer from a model
        built with a different input size will likely fail on a shape mismatch.
    :param is_regression: if True, the head is a sigmoid Dense named
        ``last_dense_reg``. Otherwise it is a softmax Dense named
        ``last_dense_clf``.
    :param num_classes: number of units in the output Dense layer.
    :param freeze: if True, mark every layer except the last two as
        non-trainable. With a head, those are ``pre_last_dense_reg`` and the
        output Dense.
    :param remove_head: if True, omit the output Dense layer and expose the
        16-unit ``pre_last_dense_reg`` features as the model output.
    :return: an uncompiled ``Model``.
    """
    inputs = Input(shape=(input_shape, input_shape, 3))

    x = inputs
    for n in range(n_blocks):
        x = Conv2D(32, (3, 3), padding="same", name='conv_%d' % n)(x)
        x = Activation("relu")(x)
        x = MaxPooling2D()(x)

    x = Flatten()(x)
    x = Dense(16, activation="relu", name='pre_last_dense_reg')(x)

    # Regression and classification heads use different layer names, so
    # by-name loading never copies a classifier head into a regressor (or vice versa).
    if remove_head:
        y = x
    elif is_regression:
        y = Dense(num_classes, activation="sigmoid", name='last_dense_reg')(x)
    else:
        y = Dense(num_classes, activation="softmax", name='last_dense_clf')(x)
    model = Model(inputs, y)

    # Truthiness check: both '' and None mean "no weights to load".
    if weights:
        model.load_weights(weights, by_name=True)

    if freeze:
        for layer in model.layers[:-2]:
            layer.trainable = False

    return model

def train_cifar100(num_classes=100, batch_size=32, epochs=20):
    """Train the simple CNN as a CIFAR-100 classifier, for use as a pre-trained backbone.

    Pixels are scaled to [0, 1]. Training batches are augmented with random
    shifts of up to 10% of width/height and random horizontal flips. The test
    split is used un-augmented as validation data.

    :param num_classes: number of output classes. It must be at least the
        number of CIFAR-100 labels (100) for the one-hot encoding to succeed.
    :param batch_size: mini-batch size for training.
    :param epochs: number of training epochs.
    :return: the trained (compiled) Keras ``Model``.
    """
    # CIFAR-100 so the labels match the 100-way one-hot encoding and softmax head.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)

    x_train = x_train.astype('float32') / 255
    x_test = x_test.astype('float32') / 255

    print('Using real-time data augmentation.')
    # Only the settings below differ from ImageDataGenerator's defaults.
    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        width_shift_range=0.1,   # random horizontal shift, fraction of total width
        height_shift_range=0.1,  # random vertical shift, fraction of total height
        fill_mode='nearest',     # how pixels shifted in from outside are filled
        horizontal_flip=True,    # random left/right flips
    )
    # No feature-wise statistics are enabled, so this is effectively a no-op.
    # It is kept so that enabling featurewise_center/std or ZCA later just works.
    datagen.fit(x_train)

    dataset_train = datagen.flow(x_train, y_train, batch_size=batch_size)
    num_steps_train = x_train.shape[0] // batch_size

    # Input side length comes from the data (32 for CIFAR) rather than a literal.
    model = generate_simple_cnn_regression_model(x_train.shape[1], is_regression=False, num_classes=num_classes)
    opt = Adam(lr=1e-3, decay=1e-3 / 200)
    model.compile(loss='categorical_crossentropy',
                  metrics=['categorical_crossentropy', 'accuracy'],
                  optimizer=opt)
    model.summary()

    model.fit_generator(dataset_train,
                        steps_per_epoch=num_steps_train,
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        workers=1)
    return model


if __name__ == "__main__":
    pre_train_with_cifar100 = True
    weights = ''

    def _compile_train_evaluate(model, train_x, train_y, test_x, test_y, normalizer, save_path):
        """Compile ``model`` for MSE regression, train it, evaluate on the test split, and save it to ``save_path``."""
        opt = Adam(lr=1e-3, decay=1e-3 / 200)
        model.compile(loss='mean_squared_error',
                      metrics=['mean_absolute_percentage_error', 'mean_absolute_error', 'mean_squared_error'],
                      optimizer=opt)
        model.summary()
        model = train_model(train_x, train_y, test_x, test_y, model, show_plot=True, epochs=500, batch_size=32)
        evaluate_regression_model(model, test_x, test_y, normalizer, show_plot=True)
        model.save(save_path)
        return model

    if pre_train_with_cifar100:
        # Stage 0: pre-train the backbone on CIFAR-100 unless a checkpoint exists.
        # Only the file path is needed later (weights are loaded by layer name),
        # so an existence check replaces loading the whole model inside a bare
        # `except:` that also swallowed unrelated errors.
        file_weight_cifar100 = 'pretrained_cifar100.h5'
        if not os.path.isfile(file_weight_cifar100):
            model = train_cifar100()
            model.save(file_weight_cifar100)
        weights = file_weight_cifar100

    (trainX_data, trainX_img, trainY, testX_data, testX_img, testY), normalizer = load_house_dataset_data(
        test_size=0.2, random_state=666, type=DatasetType.Both)

    # Only the bathroom images are used as model input.
    trainX = trainX_img['bathroom_img']
    testX = testX_img['bathroom_img']
    # Side length passed as the CNN input size (images presumably square, RGB — confirm).
    input_shape = trainX.shape[1]

    if pre_train_with_cifar100:
        # Stage 1: keep the pre-trained layers frozen; only the last two layers train.
        file_weight_finetune = 'regression_model_image_finetune.h5'
        model = generate_simple_cnn_regression_model(input_shape, weights=weights, freeze=True)
        _compile_train_evaluate(model, trainX, trainY, testX, testY, normalizer, file_weight_finetune)
        weights = file_weight_finetune
        final_model_weight = 'regression_model_image_pretrained.h5'
    else:
        final_model_weight = 'regression_model_image_from_scratch.h5'

    # Stage 2: train all layers, starting from the stage-1 weights when available.
    model = generate_simple_cnn_regression_model(input_shape, weights=weights)
    _compile_train_evaluate(model, trainX, trainY, testX, testY, normalizer, final_model_weight)