import argparse
import inspect
import mimetypes
import os
import subprocess
import sys

import magic
import torch
import torch.nn as nn

import model

"""  
Argument parser
"""

parser = argparse.ArgumentParser(description='run the model on a folder (or a single file) to generate outputs')

parser.add_argument('model_path', metavar='model_path', type=str, help='path to the .model file - e.g. saved_models/test.model')

parser.add_argument('test_folder_path', metavar='test_folder_path', type=str, help='path to the folder containing new files e.g. test_folder_path.py')

parser.add_argument('--cutoff', dest='cutoff', type=float, default=0.7, help='optional cutoff value for classifying a file as encrypted (default: 0.7)')

args = parser.parse_args()

"""
Some static values for files to be excluded, and colors for printing

safe_types: files that will not be marked as encrypted regardless of entropy values

bcolors: static color variables for printing
"""

safe_types = ['application/zip', 
              'application/x-gzip', 
              'application/x-bzip2',
              'image/png',
              'image/jpeg',
              'video/quicktime',
              'image/png',
              'image/svg',
              'application/x-rar',
              'application/java-archive']
              
class bcolors:
    """ANSI SGR escape sequences for coloured terminal output.

    ENDC resets all attributes; append it after coloured text.
    """

    # foreground colours
    HEADER = '\x1b[95m'
    OKBLUE = '\x1b[94m'
    OKGREEN = '\x1b[92m'
    WARNING = '\x1b[93m'
    FAIL = '\x1b[91m'

    # reset and text attributes
    ENDC = '\x1b[0m'
    BOLD = '\x1b[1m'
    UNDERLINE = '\x1b[4m'

"""
Here, we load the model and set up some static variables and counting variables that will be used to keep track of the new files we will read. 
"""
     
# Load the model
loaded_model = torch.load(args.model_path)

# Get mean and standard deviation (saved during training) 
feature_means = loaded_model.get_means()
feature_stds = loaded_model.get_stds()

# Set model to eval mode
loaded_model.eval()

# Variables to keep count - used for printing later
totalcount = 0
encryptedcount = 0

# array that will store generated features for the new testing set
features_array = []

# maps containing index to name (names_indices) and index to MIME type (mime_indices) mappings
names_indices = dict()
mime_indices = dict()

# index to iterate over
index = 0

# keep track of number of files that could not be read - used for printing later
exceptions = 0

"""
The following loop reads files in the folder specified in the test_folder_path argument. Each file generates a list of features, a mime type, and a name (just the name of the file being read).
"""
# Open the file with read only permit
#f = open('C:\\Users\\Parablu\Documents\\rw detection\\serversetup\\source_oct\\my_text_file.txt')
f = open(args.test_folder_path)
# use readline() to read the first line
Lines = f.readlines()
f.close()
for line in Lines: 
     print(type(line))
     print((line.strip()))
     if (line.strip()) != "0,File-bytes,Entropy,Chi-square,Mean,Monte-Carlo-Pi,Serial-Correlation":
          print("pvn...............")
          print(features_array)
          captured_csv = line.split('\r\n')[0]
          print(line)
          print(captured_csv)
          file_name = captured_csv.split('::')[0]
          print("file_name")
          print(file_name)
          captured_csv = captured_csv.split('::')[1]
          print(type(captured_csv))
          captured_cleaned = captured_csv.strip().split(',')
          print(captured_cleaned)
          captured_array = captured_cleaned[1:]
          print(captured_array)
          try:
               captured_array = list(map(float,captured_array))
          except ValueError:
                  print('File '  + ' could not be read'+file_name)
                  exceptions = exceptions + 1
          features_array.append(captured_array)
          print('features_array')
          print(features_array)
          names_indices[index] = file_name
          #mime = magic.Magic(mime=True)
          #mime_type = mime.from_file("filename.txt")
          mime_type, encoding = mimetypes.guess_type(file_name)
          mime_indices[index] = mime_type
          index = index + 1
        
        
    
"""
Simple helper function that converts (based on a  cutoff argument) any Nx2 matrix containing probabilities of encrypted vs not-encrypted items into an Nx1 array that contains a prediction of whether that entity is encrypted(1 for encrypted, 0 for not encrypted).
"""
def mark_encrypted(probss, cutoff=0.7):
    encrypted = []*probss.size()[0]
    for probs in probss:
        if probs[0] > cutoff:
            encrypted +=[0]
        else:
            encrypted +=[1]
    return encrypted
    
"""
This is where the main predictions are run using the loaded model - all the pre-processing above is used to generate the features_array which is passed to the loaded model. 
"""
# convert raw features generated in the loop above into a pytorch tensor 
test_features = torch.tensor(features_array)

# normalize the features and run the model on the features
values = loaded_model(((test_features - feature_means)/feature_stds).float())

cutoff = args.cutoff

# convert output raw values into probabilities
probs = nn.functional.softmax(values, dim=1)
# use mark_encrypted function to convert probabilities into predictions
enc = mark_encrypted(probs, cutoff)


"""
Variables to keep track of various metrics
"""

# files marked encrypted
encrypted = 0

# files marked not encrypted directly by the model
unencrypted_direct = 0

# files marked not encrypted by inspecting the MIME type
unencrypted_inspection = 0
    
"""
This loop iterates through each index and - 
1. If prediction is 1 (model says the file is encrypted) - it does a check on the MIME type of the file and if it is one of the "safe_types" then marks it as unencrypted_inspection.
2. If prediction is 1, but MIME type is not one of the "safe_types" then marks file as encrypted
3. If prediction is 0 (model says the fiel is not encrypted) - it marks the file as unencrytped_direct
"""
for index in names_indices:
    prediction = enc[index]
    name = names_indices[index]
    type = mime_indices[index]
    if prediction == 1: 
        print(probs[index])
        if type in safe_types:
            print('File ' + name + ' is actually not encrypted - it is either compressed or an image/video file.')
            print("result::"+name+ '::not encrypted')
            unencrypted_inspection += 1
        else:
            print("result::"+name+ '::encrypted')
            encrypted += 1
    else:
        print("result::"+name+ '::not encrypted')
        unencrypted_direct += 1

        
"""
Printing
"""
print("")
print("")
print("Total files scanned = {}".format(len(names_indices)+exceptions))
print("Files encrypted = {}".format(encrypted))
print("Files unencrypted (direct) = {}".format(unencrypted_direct))
print("Files unencrypted (upon inspection) = {}".format(unencrypted_inspection))
print("Files that could not be read = {}".format(exceptions))



