import argparse
import os
import time

import numpy as np
from scipy import io
def normalized_max(data):
    """Scale ``data`` so that its maximum element maps to 1.

    Every element is divided by ``np.max(data)``. There is no special handling
    for a zero or negative maximum.
    """
    peak = np.max(data)
    return data / peak
def normalized_std(data):
    """Scale ``data`` by its sample standard deviation.

    Uses ``np.std(data, ddof=1)`` (Bessel-corrected), so at least two elements
    are needed for a finite result.
    """
    spread = np.std(data, ddof=1)
    return data / spread

st = time.time()
parse = argparse.ArgumentParser()
# Other data location kept for reference only (previously a bare, unused string):
#   /home/lzy/UltrasonicPrediction/Dataset/ultrasonic_data/data/test_data
parse.add_argument('--read_path', '-dt', type=str,
                   default='/home/lzy/UltrasonicPrediction/Dataset/ultrasonic_data/dataset_mat/normalized',
                   help="folder containing the 'remove' and 'non_remove' sub-folders")
parse.add_argument('--train', type=str,
                   default='train_index.txt',
                   help='text file listing one training .mat path per line')
parse.add_argument('--test', type=str,
                   default='test_index.txt',
                   help='text file listing one test .mat path per line')
parse.add_argument('--save_path', '-dp', type=str,
                   default='/home/lzy/UltrasonicPrediction/Dataset/ultrasonic_data/partition_data/',
                   help='output folder for the partitioned .mat files')
parse.add_argument('--input_size', '-size', type=int, default=1,
                   help='number of consecutive positions grouped into one sample (must be > 0)')
parse.add_argument('--remove', '-r', type=int, default=0,
                   help="non-zero: use the 'remove' sub-folder (nz=896); 0: 'non_remove' (nz=1024)")

# params
args = parse.parse_args()
if args.input_size <= 0:
    # A zero/negative group size would divide by zero or build negative shapes later on.
    parse.error('--input_size must be a positive integer')
read_path = args.read_path
save_path = args.save_path
remove_flag = args.remove
input_size = args.input_size
train = args.train
test = args.test


def _read_index(index_path):
    """Return the non-blank, whitespace-stripped lines of an index file.

    Stripping removes '\\n' as well as a stray '\\r' from CRLF files. Skipping
    blank lines avoids passing '' to the .mat loader. The file is closed
    deterministically.
    """
    with open(index_path) as index_file:
        return [line.strip() for line in index_file if line.strip()]


train_list = _read_index(train)
test_list = _read_index(test)

print(input_size)
# --remove selects which sub-folder of read_path is reported below, and the
# value nz that (together with the factor 96) fixes how many positions are
# taken from each file.
if remove_flag:
    nz = 896
    inputs_folder_path = os.path.join(read_path, 'remove')
else:
    nz = 1024
    inputs_folder_path = os.path.join(read_path, 'non_remove')

print(nz)
# 0: keep input_data as loaded; 1: stack its real and imaginary parts (see loading code).
real_flag = 0
# Read inputs
print('Inputs PATH: ', inputs_folder_path)

# Samples per file: nz * 96 positions grouped input_size at a time.
# Integer division avoids float round-off; any remainder is dropped, so warn about it.
total_positions = nz * 96
reshape_size = total_positions // input_size
if total_positions % input_size:
    print('Warning: input_size %d does not divide %d; the last %d positions of each file are dropped.'
          % (input_size, total_positions, total_positions % input_size))
# Default for t (second-axis length of input_data) before any file is read.
t = 0
# Generate Train / Test datasets.


def _sample_name(mat_path):
    """Return the file name of ``mat_path`` without its directory or extension."""
    return os.path.splitext(os.path.basename(mat_path))[0]


def _load_samples(mat_path):
    """Load one .mat file and regroup it into ``reshape_size`` samples.

    Reads the ``input_data`` (3-D) and ``MV`` variables. Positions are
    enumerated row-major over the first and last axes of ``input_data``:
    flat index k maps to (k // n_cols, k % n_cols). The first
    ``reshape_size * input_size`` positions are grouped ``input_size`` at a time.

    Returns:
        inputs: complex64 array, shape (reshape_size, n_t, input_size).
        labels: float64 array, shape (reshape_size, input_size).

    Raises:
        IndexError: if the file holds fewer positions than required.
    """
    data = io.loadmat(mat_path)
    raw = data['input_data']
    if real_flag:
        # Stack real and imaginary parts along axis 1 instead of keeping complex values.
        raw = np.hstack([np.real(raw), np.imag(raw)])
    mv = data['MV']

    _, n_t, n_cols = raw.shape  # also enforces that input_data is 3-D
    rows, cols = np.divmod(np.arange(reshape_size * input_size), n_cols)

    # Advanced indices separated by a slice put the gathered axis first, so
    # picked has shape (reshape_size * input_size, n_t). This replaces the
    # original per-element Python double loop with a single vectorized gather.
    picked = raw[rows, :, cols]
    inputs = picked.reshape(reshape_size, input_size, n_t).transpose(0, 2, 1)
    labels = mv[rows, cols].reshape(reshape_size, input_size)
    return np.ascontiguousarray(inputs, dtype=np.complex64), labels.astype(np.float64)


train_inputs = []
train_MV = []
train_name = []
print('********** Generate Train Dataset **********')
for train_path in train_list:
    train_path = train_path.strip()  # handles '\n' and CRLF '\r'
    if not train_path:
        continue
    print(train_path)
    train_name.append(_sample_name(train_path))
    sample_inputs, sample_mv = _load_samples(train_path)
    train_inputs.append(sample_inputs)
    train_MV.append(sample_mv)

if not train_name:
    raise SystemExit('No training .mat files were listed in the train index; nothing to partition.')

# Merge the per-file axis: (files, reshape_size, n_t, input_size) -> (files * reshape_size, n_t, input_size).
train_inputs = np.array(train_inputs)
train_MV = np.array(train_MV)
train_inputs = train_inputs.reshape((-1,) + train_inputs.shape[2:])
train_labels = train_MV.reshape((-1, input_size))


print('********** Generate Test Dataset **********')
test_inputs = []
test_MV = []
test_name = []
for test_path in test_list:
    test_path = test_path.strip()
    if not test_path:
        continue
    print(test_path)
    test_name.append(_sample_name(test_path))
    # NOTE: no normalization is applied here; normalized_max / normalized_std
    # are defined in this file but not called.
    sample_inputs, sample_mv = _load_samples(test_path)
    test_inputs.append(sample_inputs)
    test_MV.append(sample_mv)

# Unlike the training set, test data keeps its per-file axis so each slice lines up with test_name.
test_inputs = np.array(test_inputs)
test_labels = np.array(test_MV)
print(train_inputs.shape)
print(test_inputs.shape)

# Save partition data.
# Create the output folder up front so savemat does not fail after all the processing.
if save_path:
    os.makedirs(save_path, exist_ok=True)

# NOTE(review): both files receive the identical train+test payload, although the
# second name suggests test-only data; confirm which variables each consumer expects.
partition = {'train_inputs': train_inputs, 'train_labels': train_labels,
             'test_inputs': test_inputs, 'test_labels': test_labels, 'test_name': test_name}
for file_template in ('complex_train_test_normalized_is_%d.mat',
                      'complex_test_normalized_is_%d.mat'):
    io.savemat(os.path.join(save_path, file_template % input_size), partition)
print('Total times', time.time() - st)
