from pathlib import Path
import numpy as np
import torch
from model import NNModel
import torch.optim as optim
import torch.nn as nn
from utils import batch_iter
import argparse

"""
Argument parser
"""
parser = argparse.ArgumentParser(description='Read files containing features of ransomware-encrypted and regular files, and train a PyTorch model')
parser.add_argument('regular_file', metavar='reg', type=str, help='path to file with regular files\' features - e.g. saved_features/regular.txt')
parser.add_argument('rw_file', metavar='rw', type=str, help='path to file with ransomware encrypted files\' features - e.g. saved_features/ransomware.txt')
parser.add_argument('model_pth', metavar='model_pth', type=str, help='path and name to save trained model - e.g. saved_models/output_21_10_2020')
parser.add_argument('--split', dest='train_test_split', type=float, default=0.7, help='fraction of samples to use for training (default: 0.7)')
parser.add_argument('--hidden_size', dest='hidden_size', type=int, default=256, help='number of neurons in the hidden layer (default: 256)')
parser.add_argument('--eval', dest='eval', default=False, action='store_true', help='run quick eval after training')

args = parser.parse_args()
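
# Example invocation (the script name and paths here are illustrative, not fixed):
#   python train.py saved_features/regular.txt saved_features/ransomware.txt saved_models/output_21_10_2020 --eval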

""" 
Set up the feature matrices - regular and ransomware
"""
regular_mat = np.zeros((0,7))
rw_mat = np.zeros((0,7))

"""
This loop reads the file containing features of regular (non-encrypted) data and appends 0 to the matrices - representing the label "not-encrypted"
"""
with open(args.regular_file) as file:
    for line in file.readlines():
        if line[0] == '1':
            arr = line.strip().split(",")
            arr = (list(map(float, arr))[1:])
            arr.append(0)
            regular_mat = np.vstack([regular_mat,arr])
    print(regular_mat.shape)
    print(np.mean(regular_mat,axis=0))
    
"""
This loop reads the file containing features of ransomware (encrypted) data and appends 1 to the matrices - representing the label "encrypted"
"""
with open(args.rw_file) as file:
    for line in file.readlines():
        # this filters for lines that contain actual features, the rest of the lines have filler information
        if line[0] == '1':
            arr = line.strip().split(",")
            arr = (list(map(float, arr))[1:])
            arr.append(1)
            rw_mat = np.vstack([rw_mat,arr])
    print(rw_mat.shape)
    print(np.mean(rw_mat, axis=0))

"""
The following code combines the two matrices generated above and randomly splits it into training and test data.
"""
comb_mat = np.vstack([regular_mat, rw_mat])
num_samples = comb_mat.shape[0]
train_size = int(num_samples * args.train_test_split)
test_size = num_samples - train_size
# permute once and slice, so the train and test sets are disjoint
shuffled = np.random.permutation(comb_mat)
train_data = shuffled[:train_size]
test_data = shuffled[train_size:]
print('train data shape = {}'.format(train_data.shape))
print('test data shape = {}'.format(test_data.shape))

# convert training set from numpy to pytorch
data_tensor = torch.from_numpy(train_data)

# split training data from 7 columns into labels (the last column) and the training data itself (the first 6)
labels_tensor = data_tensor[:,6]    
train_tensor = data_tensor[:,:6]  

# convert test set from numpy to pytorch
test_data_tensor = torch.from_numpy(test_data)
        
# split testing data from 7 columns into labels (the last column) and the test data itself (the first 6)
test_tensor = test_data_tensor[:,:6]
test_labels_tensor = test_data_tensor[:,6]   

"""
 Print examples and shapes - can safely remove these
"""
print('train data eg = {}'.format(train_tensor[:15]))
print('label data eg = {}'.format(labels_tensor[:15]))

print(train_tensor.shape)
print(train_tensor.dtype)

print(test_tensor.shape)

"""
Initialize the model and optimizers - prepare for training
"""
model = NNModel(6, 2, args.hidden_size)  # 6 input features, 2 output classes
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# running totals used to recover per-feature means and stds after training
feature_sums = torch.zeros(6)
feature_vars = torch.zeros(6)
num = 0

"""
Train for 5000 epochs (number of times we go through the training set)
"""
for epoch in range(5000):
    cumulative_loss = 0.0
    # use the batch_iter utility to draw 256 samples at a time; it also yields each batch's per-feature means and stds
    for input_means, input_stds, inputs, labels in batch_iter(train_tensor, labels_tensor, 256):
        feature_sums = feature_sums + input_means
        feature_vars = feature_vars + torch.pow(input_stds,2)
        num = num+1
        optimizer.zero_grad()
        outputs = model(inputs.float())
        loss = criterion(outputs, labels.long())
        loss.backward()
        optimizer.step()
        cumulative_loss += loss.item()
    # CrossEntropyLoss averages over the batch, so multiply the summed batch losses
    # by the batch size for an approximate total loss per epoch
    print("Cumulative loss at epoch {} = {}".format(epoch, cumulative_loss*256))

"""
Store the means and standard deviations to adjust  at the feature of new data at runtime. 
"""
# average the accumulated per-batch statistics (an approximation of the global per-feature stats)
feature_means = feature_sums/num
feature_stds = torch.pow(feature_vars/num, 0.5)

# print for information - can safely remove these
print('Feature means = {}'.format(feature_means))
print('Feature stds = {}'.format(feature_stds))
print('Iterations = {}'.format(num))

"""
Save the model and the parameters at the path specified in arguments
"""
model.set_means(feature_means)
model.set_stds(feature_stds)

path = args.model_pth
# save full model
torch.save(model, path+'.model')
# save just the model parameters (state_dict)
torch.save(model.state_dict(), path+'.params')
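
# A minimal reloading sketch (assumes the NNModel constructor arguments used above;
# newer PyTorch versions may need weights_only=False to unpickle the full model):
#   trained = torch.load(path + '.model')
#   rebuilt = NNModel(6, 2, args.hidden_size)
#   rebuilt.load_state_dict(torch.load(path + '.params'))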

"""
If eval is set to true - use the testing data to print performance on the test set
"""
if args.eval:
    model.eval()

    # normalize the test features with the stored statistics (small epsilon avoids division by zero)
    test_tensor = (test_tensor - feature_means)/(feature_stds + 1e-6)

    test_outputs = model(test_tensor.float(),testing=True)
    test_amax = torch.argmax(test_outputs, dim=1)
    print(test_amax[:20])
    print(test_labels_tensor[:20])
    print(test_amax == test_labels_tensor)
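
    # a simple summary metric added for convenience: overall accuracy on the test set
    accuracy = (test_amax == test_labels_tensor.long()).float().mean().item()
    print('test accuracy = {:.4f}'.format(accuracy))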