import os
import struct
import time
import numpy as np
import argparse
from scipy import io


def reshape_data(data, nz):
    """Reorder a 2-D ``(ndat, nx * nz)`` sample matrix into ``(nz, ndat, nx)``.

    Column ``i * nz + j`` of *data* (``0 <= i < nx``, ``0 <= j < nz``) is
    placed at ``out[j, :, i]``.  ``ndat`` and ``nx`` are taken from
    ``data.shape`` instead of being fixed, so any file geometry is supported.

    Parameters
    ----------
    data : array_like
        2-D array whose column count is a multiple of *nz*.
    nz : int
        Number of consecutive columns per ``i`` index (fastest-varying
        index in the column layout).

    Returns
    -------
    numpy.ndarray
        New C-contiguous ``complex64`` array of shape ``(nz, ndat, nx)``.
        Values are cast to ``complex64``, so ``complex128`` input loses
        precision.

    Raises
    ------
    ValueError
        If *data* is not 2-D, *nz* is not positive, or the number of
        columns is not a multiple of *nz*.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError('data must be 2-D, got shape %r' % (data.shape,))
    if nz <= 0:
        raise ValueError('nz must be positive, got %r' % (nz,))
    ndat, ncols = data.shape
    if ncols % nz:
        raise ValueError('column count %d is not a multiple of nz=%d' % (ncols, nz))
    nx = ncols // nz
    # data[k, i * nz + j] == view[k, i, j]; moving axes to (j, k, i) yields out[j, k, i].
    view = data.reshape(ndat, nx, nz).transpose(2, 0, 1)
    return np.ascontiguousarray(view, dtype=np.complex64)


def normalized(data):
    """Return *data* scaled so its sample standard deviation is 1.

    The standard deviation is computed over the whole array (``np.std`` with
    no ``axis``) using ``ddof=1``.  For complex input, NumPy takes the
    magnitude of each deviation before squaring, so the divisor is a single
    non-negative real number and real/imaginary parts are scaled equally.
    The data are not mean-centred.

    Parameters
    ----------
    data : numpy.ndarray
        Real or complex samples.

    Returns
    -------
    numpy.ndarray
        A new array equal to ``data / std(data, ddof=1)``.  If that standard
        deviation is zero (constant input) or undefined (fewer than two
        elements), NumPy yields ``inf``/``nan`` values with a RuntimeWarning.
    """
    return data / np.std(data, ddof=1)


# Per frame, each of the nx * nz positions stores four consecutive runs of
# ``ndat`` native-endian float64 values: input real, input imag,
# teacher real, teacher imag.  Positions are stored with the nz index
# varying fastest.
_BLOCKS_PER_POSITION = 4
_FLOAT64_SIZE = 8
_HEADER_SIZE = 8  # four native-endian uint16 fields


def _read_header(f):
    """Read and return the header fields ``(nx, nz, ndat, nframe)`` from *f*."""
    header = f.read(_HEADER_SIZE)
    if len(header) != _HEADER_SIZE:
        raise ValueError('truncated header: expected %d bytes, got %d'
                         % (_HEADER_SIZE, len(header)))
    return struct.unpack('4H', header)


def _read_frame(f, nx, nz, ndat):
    """Read one frame from *f* and return ``(input_data, teacher_data)``.

    Both arrays are complex128 with shape ``(ndat, nx * nz)``; column
    ``i * nz + j`` holds the samples of position ``(i, j)``.
    """
    npos = nx * nz
    nbytes = npos * _BLOCKS_PER_POSITION * ndat * _FLOAT64_SIZE
    raw = f.read(nbytes)
    if len(raw) != nbytes:
        raise ValueError('truncated frame: expected %d bytes, got %d' % (nbytes, len(raw)))
    # Axes: (position, block, sample).
    samples = np.frombuffer(raw, dtype=np.float64).reshape(npos, _BLOCKS_PER_POSITION, ndat)
    input_data = (samples[:, 0, :] + 1j * samples[:, 1, :]).T
    teacher_data = (samples[:, 2, :] + 1j * samples[:, 3, :]).T
    return input_data, teacher_data


def main():
    """Convert every frame of a binary sample file into ``data<k>.mat``.

    Each output file contains:
      input_data   -- normalized input samples, passed through reshape_data
      teacher_data -- unnormalized teacher samples, passed through reshape_data
      DAS          -- (nz, nx): |mean over samples| of the normalized input
      MV           -- (nz, nx): |vdot(input, teacher)| for each position
    """
    start = time.perf_counter()
    parser = argparse.ArgumentParser()
    # Other dataset used previously:
    #   /home/lzy/UltrasonicPrediction/Dataset/ultrasonic_data/data/data/teacher_data_carotid_long.bin
    parser.add_argument('--data_path', '-dt', type=str,
                        default='/home/lzy/UltrasonicPrediction/Dataset/ultrasonic_data/data/data/teacher_data1.bin',
                        help='input .bin file')
    # Other output directory used previously:
    #   /home/lzy/UltrasonicPrediction/Dataset/ultrasonic_data/train_test/original
    parser.add_argument('--save_path', '-s', type=str,
                        default='/home/lzy/UltrasonicPrediction/Dataset/ultrasonic_data/Sample_100_complex',
                        help='output directory for the per-frame .mat files')
    args = parser.parse_args()
    save_path = args.save_path

    # savemat does not create directories; do it once up front.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    with open(args.data_path, 'rb') as f:
        nx, nz, ndat, nframe = _read_header(f)
        print(nx, nz, ndat, nframe)

        for ifrm in range(nframe):
            print(ifrm)
            input_data, teacher_data = _read_frame(f, nx, nz, ndat)

            # Only the input is normalized; teacher data is used as read.
            input_data = normalized(input_data)

            # DAS[j, i] = |mean(input[:, i*nz + j])|, shape (nz, nx).
            das = np.abs(input_data.mean(axis=0)).reshape(nx, nz).T

            # MV[j, i] = |vdot(input[:, k], teacher[:, k])| with k = i*nz + j;
            # np.vdot conjugates its first argument, reproduced here per column.
            mv = np.abs(np.sum(np.conj(input_data) * teacher_data, axis=0)).reshape(nx, nz).T

            reshape_input_data = reshape_data(input_data, nz)
            reshape_teacher_data = reshape_data(teacher_data, nz)

            io.savemat(os.path.join(save_path, 'data' + str(ifrm) + '.mat'),
                       {'input_data': reshape_input_data, 'teacher_data': reshape_teacher_data,
                        'MV': mv, 'DAS': das})

            # Disabled alternative output: drop the first 128 entries along the
            # nz axis and save under <save_path>/remove/:
            # io.savemat(os.path.join(save_path, 'remove', 'data' + str(ifrm) + '.mat'),
            #            {'input_data': reshape_input_data[128:, :, :],
            #             'teacher_data': reshape_teacher_data[128:, :, :],
            #             'MV': mv[128:, :], 'DAS': das[128:, :]})

    print('Total time:', time.perf_counter() - start)


if __name__ == '__main__':
    main()
