"""
This file contains functions which will extract features from the report and from the sample.
"""

import pefile
from os import remove, rename, mkdir
from shutil import copy
import pyzipper
from py7zr import SevenZipFile
from conf import FAMILIES

# Names of the static feature kinds. They are used as sub-folder names by
# prepareStaticW and match the suffixes of the files written by trimHead
# (".bytes"), writeStaticAPIs (".apis") and writeEntropy (".entropy").
static_features = ["bytes", "apis", "entropy"]

def extractArchive(source, ext=""):
    """Extract (or copy) the sample held by ``source`` into the working directory.

    Inputs are dispatched in this order:
      * path ends in "7z": 7z archive protected by the password "infected";
      * file is not a zip container: copied as-is, named after the
        second-to-last "_"-separated token of its file name;
      * path ends in "zip" and its fifth "/"-separated component starts
        with "_": AES-encrypted zip (labelled "MB" by the original author);
      * any other path ending in "zip": regular password-protected zip
        (labelled "VS" by the original author);
      * path ends in "apk": copied as-is under its own file name.

    Args:
        source: "/"-separated path of the archive or sample.
        ext: extension appended when the extracted name contains no ".".

    Returns:
        Name of the extracted/copied sample, or "" when the format is not
        recognized.
    """
    sname = ""
    parts = source.split("/")
    if source.endswith("7z"):
        # This test has to come before the is_zipfile() guard: a 7z archive
        # is not a zip file, so the guard would otherwise copy the archive
        # itself instead of extracting it.
        with SevenZipFile(source, password="infected") as arch:
            # getnames() yields plain strings (list() yields FileInfo objects).
            sname = arch.getnames()[0]
            arch.extractall()
    elif not pyzipper.is_zipfile(source):
        # Plain (non-archived) sample: copy it locally and return right away,
        # without touching its extension.
        sname = parts[-1].split("_")[-2]
        copy(source, sname)
        return sname
    #extract zip files from MB
    elif source.endswith("zip") and len(parts) > 4 and parts[4].startswith("_"):
        # Length check avoids an IndexError on paths with fewer components.
        with pyzipper.AESZipFile(source) as zf:
            zf.pwd = b"infected"
            sname = zf.filelist[0].filename
            zf.extractall(".")
    #extract zip files from VS
    elif source.endswith("zip"):
        with pyzipper.PyZipFile(source) as zf:
            zf.pwd = b"infected"
            sname = zf.filelist[0].filename
            zf.extractall(".")
    elif source.endswith("apk"):
        # An apk is itself a zip container, so it gets past is_zipfile().
        sname = parts[-1]
        copy(source, sname)
    else:
        print("Archive format for " + source + " not recognized")
    # NOTE(review): with the default ext="" an extension-less file is renamed
    # to "<name>." -- confirm callers always pass ext when this matters.
    if sname != "" and "." not in sname:
        rename(sname, sname + "." + ext)
        sname = sname + "." + ext
    return sname

def extractStaticAPIs(source, exe=True):
    """Return the names of the APIs imported by a PE sample.

    Args:
        source: path of the executable when ``exe`` is true; otherwise the
            static section of the cuckoo report (a mapping holding a
            'pe_imports' list whose items hold an 'imports' list).
        exe: whether ``source`` is an executable path (True) or a report.

    Returns:
        List of imported API names. For executables, [] is returned when the
        file is not a valid PE or exposes no import directory.
    """
    if not exe:  # imports retrieved from the report
        return [api['name'] for dll in source['pe_imports'] for api in dll['imports']]

    # imports are retrieved from the exe
    pe = None
    try:
        pe = pefile.PE(source)
        imports = []
        for entry in pe.DIRECTORY_ENTRY_IMPORT:
            for imp in entry.imports:
                # Imports by ordinal carry no name; skip them instead of
                # letting the AttributeError below wipe out the whole list.
                if imp is not None and imp.name is not None:
                    # Malformed samples may hold non-UTF-8 bytes in names.
                    imports.append(imp.name.decode(errors="replace"))
        return imports
    except pefile.PEFormatError as e:
        print(e)
        return []
    except AttributeError as e:
        # Raised when the PE has no DIRECTORY_ENTRY_IMPORT attribute.
        print(e)
        return []
    finally:
        # Release the file mapping held by pefile.
        if pe is not None:
            pe.close()

def extractSectionEntropy(source, exe=True):
    """Return a mapping from PE section name to that section's entropy.

    Args:
        source: path of the executable when ``exe`` is true; otherwise the
            static section of the cuckoo report (a mapping holding a
            'pe_sections' list of {'name': ..., 'entropy': ...} items).
        exe: whether ``source`` is an executable path (True) or a report.

    Returns:
        Dict of section name -> entropy; {} when the executable is not a
        valid PE file.
    """
    if not exe:
        return {section['name']: section['entropy'] for section in source['pe_sections']}

    pe = None
    try:
        pe = pefile.PE(source, fast_load=True)
        # Section names are NUL-padded bytes. Decode them rather than calling
        # str() on them, which would produce keys such as "b'.text'" that
        # disagree with the names found in the report.
        return {
            entry.Name.strip(b"\x00").decode(errors="replace"): entry.get_entropy()
            for entry in pe.sections
        }
    except pefile.PEFormatError as e:
        print(e)
        return {}
    finally:
        # Release the file mapping held by pefile.
        if pe is not None:
            pe.close()

def extractDynamicAPIs(report):
    """Sum, over all processes, how many times each API was called at runtime.

    ``report`` must be only the behavior section of the report: its
    'apistats' entry maps each process to a mapping of API name -> count.

    Returns a dict of API name -> total number of usages.
    """
    totals = {}
    for per_process in report['apistats'].values():
        for name, count in per_process.items():
            try:
                totals[name] += count
            except KeyError:
                # First time this API is seen.
                totals[name] = count
    return totals

def extractProcessTree(report):
    """Placeholder: process-tree extraction is not implemented; returns None."""
    return None

def trimHead(sample):
    """Dump a PE sample, minus its headers, as hex text in "<sample>.bytes".

    Headers are removed from samples to guarantee sterility: everything that
    precedes the raw data of the first section is skipped, the rest is
    written through writeBytes.

    Args:
        sample: path of the PE file.

    Returns:
        Path of the generated ".bytes" file, or "" when the sample is not a
        valid PE, has no sections, or its first section starts at offset 0.
        In those cases no ".bytes" file is left on disk.
    """
    out_name = sample + ".bytes"

    def _fail(reason):
        # Report the problem and make sure no (stale) output file survives.
        print(reason)
        print("Problems with " + sample)
        try:
            remove(out_name)
        except FileNotFoundError:
            pass
        return ""

    # Parse the PE before creating any output, so a failure cannot leak open
    # handles or leave a half-written file behind.
    try:
        pe = pefile.PE(sample, fast_load=True)
    except pefile.PEFormatError as e:
        return _fail(e)
    try:
        #get address of the first section (0 also stands for "no sections")
        addr = pe.sections[0].PointerToRawData if pe.sections else 0
    finally:
        pe.close()
    if addr == 0:
        return _fail("")

    with open(sample, "rb") as fo, open(out_name, "w") as fn:
        fo.seek(addr)
        writeBytes(fo, fn)
    return out_name

def writeEntropy(hash):
    """Write the section entropies of a sample to "<hash>.entropy".

    Args:
        hash: path of the executable (presumably the sample is named after
            its hash -- confirm with callers).

    Returns:
        Name of the written file, or "" when no entropy could be extracted
        (in which case no file is created).
    """
    e_d = extractSectionEntropy(hash)
    if not e_d:
        return ""
    filename = hash + ".entropy"
    # The context manager closes the file even if a write fails.
    with open(filename, "w") as f:
        f.writelines(
            "Section " + section + " : " + str(entropy) + "\n"
            for section, entropy in e_d.items()
        )
    return filename

def writeStaticAPIs(hash):
    """Write the APIs imported by a sample, one per line, to "<hash>.apis".

    Args:
        hash: path of the executable (presumably the sample is named after
            its hash -- confirm with callers).

    Returns:
        Name of the written file, or "" when no import could be extracted
        (in which case no file is created).
    """
    e_l = extractStaticAPIs(hash)
    if not e_l:
        return ""
    filename = hash + ".apis"
    # The context manager closes the file even if a write fails.
    with open(filename, "w") as f:
        f.writelines(api + "\n" for api in e_l)
    return filename


def handleFail(e, fn, fo, sample):
    """Report a failure on ``sample`` and discard its partial ".bytes" output.

    Prints ``e`` and a short message, closes both file objects ``fn`` and
    ``fo``, then deletes "<sample>.bytes". Returns None.
    """
    print(e)
    print("Problems with " + sample)
    for handle in (fn, fo):
        handle.close()
    remove(sample + ".bytes")
    return None

def writeBytes(source, dest):
    """Write everything readable from ``source`` to ``dest`` as hex text.

    Each byte becomes two lowercase hex digits followed by a space; a
    newline is emitted after every eighth byte. ``source`` must be a binary
    stream and ``dest`` a text stream. Returns None.
    """
    for position, value in enumerate(source.read(), start=1):
        dest.write(format(value, "02x") + " ")
        if position % 8 == 0:
            dest.write("\n")
    return None

def prepareStaticW():
    """Create the "results/<feature>/<family>" folder tree.

    One folder is made per entry of static_features and, inside each of
    them, one per entry of FAMILIES. Folders are created with mkdir, so the
    call raises if any of them already exists. Returns None.
    """
    root = "results/"
    print("Creating folder " + root + "...")
    mkdir(root)
    for feature in static_features:
        print("Creating folder " + feature + "...")
        feature_dir = root + feature
        mkdir(feature_dir)
        for family in FAMILIES:
            print("Creating folder " + family + "...")
            mkdir(feature_dir + "/" + family)
    return None