🤬
  • ■ ■ ■ ■ ■ ■
    .gitignore
     1 +logs/*
     2 +KDD_ML_*.py
     3 +*.png
     4 +fit_pckl/
     5 +*.pkl
     6 + 
  • ■ ■ ■ ■ ■ ■
    ids2018_ML_xtd_lstm_sparse.py
     1 +# -*- coding: utf-8 -*-
     2 +"""
     3 +Created on Fri Apr 17 12:57:40 2020
     4 +
     5 +@author: karalis
     6 +"""
     7 +import pandas as pd
     8 +import numpy as np
     9 +import scipy as sp
     10 +import matplotlib
     11 +matplotlib.use('Agg')
     12 +import matplotlib.pyplot as plt
     13 +from sklearn.pipeline import Pipeline
     14 +from sklearn.preprocessing import MinMaxScaler
     15 +from sklearn.preprocessing import StandardScaler
     16 +from sklearn.preprocessing import OrdinalEncoder
     17 +from sklearn.impute import SimpleImputer
     18 +from sklearn.preprocessing import OneHotEncoder
     19 +from sklearn.compose import ColumnTransformer
     20 +from sklearn.model_selection import train_test_split
     21 +
     22 +
     23 +attack_types = {
     24 + 'normal': 'normal',
     25 +
     26 + 'back': 'DoS',
     27 + 'land': 'DoS',
     28 + 'neptune': 'DoS',
     29 + 'pod': 'DoS',
     30 + 'smurf': 'DoS',
     31 + 'teardrop': 'DoS',
     32 + 'mailbomb': 'DoS',
     33 + 'apache2': 'DoS',
     34 + 'processtable': 'DoS',
     35 + 'udpstorm': 'DoS',
     36 +
     37 + 'ipsweep': 'Probe',
     38 + 'nmap': 'Probe',
     39 + 'portsweep': 'Probe',
     40 + 'satan': 'Probe',
     41 + 'mscan': 'Probe',
     42 + 'saint': 'Probe',
     43 +
     44 + 'ftp_write': 'R2L',
     45 + 'guess_passwd': 'R2L',
     46 + 'imap': 'R2L',
     47 + 'multihop': 'R2L',
     48 + 'phf': 'R2L',
     49 + 'spy': 'R2L',
     50 + 'warezclient': 'R2L',
     51 + 'warezmaster': 'R2L',
     52 + 'sendmail': 'R2L',
     53 + 'named': 'R2L',
     54 + 'snmpgetattack': 'R2L',
     55 + 'snmpguess': 'R2L',
     56 + 'xlock': 'R2L',
     57 + 'xsnoop': 'R2L',
     58 + 'worm': 'R2L',
     59 +
     60 + 'buffer_overflow': 'U2R',
     61 + 'loadmodule': 'U2R',
     62 + 'perl': 'U2R',
     63 + 'rootkit': 'U2R',
     64 + 'httptunnel': 'U2R',
     65 + 'ps': 'U2R',
     66 + 'sqlattack': 'U2R',
     67 + 'xterm': 'U2R'
     68 +}
     69 +
     70 +
     71 +is_attack = {
     72 + "DoS":"attack",
     73 + "R2L":"attack",
     74 + "U2R":"attack",
     75 + "Probe":"attack",
     76 + "normal":"normal"
     77 +}
     78 +
     79 +
     80 +ids_path_pkl = "/home/ilias/IDS_2018/Processed Traffic Data for ML Algorithms/"
     81 +ids_path = "/home/ilias/IDS_2018/Processed Traffic Data for ML Algorithms/"
     82 +
     83 +
     84 +class read_data:
     85 + col_names = ["Dst Port" ,"Protocol" ,"Timestamp" ,"Flow Duration" ,"Tot Fwd Pkts" ,"Tot Bwd Pkts" ,
     86 + "TotLen Fwd Pkts" ,"TotLen Bwd Pkts" ,"Fwd Pkt Len Max", "Fwd Pkt Len Min", "Fwd Pkt Len Mean" ,
     87 + "Fwd Pkt Len Std" ,"Bwd Pkt Len Max" ,"Bwd Pkt Len Min" ,"Bwd Pkt Len Mean" ,"Bwd Pkt Len Std" ,
     88 + "Flow Byts/s" ,"Flow Pkts/s" ,"Flow IAT Mean" ,"Flow IAT Std" ,"Flow IAT Max" ,"Flow IAT Min" ,
     89 + "Fwd IAT Tot" ,"Fwd IAT Mean" ,"Fwd IAT Std" ,"Fwd IAT Max" ,"Fwd IAT Min" ,"Bwd IAT Tot" ,
     90 + "Bwd IAT Mean" ,"Bwd IAT Std" ,"Bwd IAT Max" ,"Bwd IAT Min" ,"Fwd PSH Flags" ,"Bwd PSH Flags" ,
     91 + "Fwd URG Flags" ,"Bwd URG Flags" ,"Fwd Header Len" ,"Bwd Header Len" ,"Fwd Pkts/s" ,
     92 + "Bwd Pkts/s" ,"Pkt Len Min" ,"Pkt Len Max" ,"Pkt Len Mean" ,"Pkt Len Std" ,"Pkt Len Var" ,
     93 + "FIN Flag Cnt" ,"SYN Flag Cnt" ,"RST Flag Cnt" ,"PSH Flag Cnt" ,"ACK Flag Cnt" ,"URG Flag Cnt" ,
     94 + "CWE Flag Count" ,"ECE Flag Cnt" ,"Down/Up Ratio" ,"Pkt Size Avg" ,"Fwd Seg Size Avg" ,
     95 + "Bwd Seg Size Avg" ,"Fwd Byts/b Avg" ,"Fwd Pkts/b Avg" ,"Fwd Blk Rate Avg" ,"Bwd Byts/b Avg" ,
     96 + "Bwd Pkts/b Avg" ,"Bwd Blk Rate Avg" ,"Subflow Fwd Pkts" ,"Subflow Fwd Byts" ,
     97 + "Subflow Bwd Pkts" ,"Subflow Bwd Byts" ,"Init Fwd Win Byts" ,"Init Bwd Win Byts" ,
     98 + "Fwd Act Data Pkts" ,"Fwd Seg Size Min" ,"Active Mean" ,"Active Std" ,"Active Max" ,
     99 + "Active Min" ,"Idle Mean" ,"Idle Std" ,"Idle Max" ,"Idle Min" ,"Label"]
     100 +
     101 + IDS2018All = pd.read_csv(ids_path+"Friday-02-03-2018_TrafficForML_CICFlowMeter.csv", dtype = {"Dst Port": str, "Timestamp": str, "": np.float64})
     102 + IDS2018Train, IDS2018Test = train_test_split(IDS2018All, test_size = 0.2, random_state = 42)
     103 + IDS2018All_con = pd.concat([IDS2018Train, IDS2018Test])
     104 +
     105 + IDS2018All_con = IDS2018All.drop("Timestamp", axis = 1)
     106 + print (IDS2018All_con.columns.tolist())
     107 + IDS2018All_con['Flow Duration'] = IDS2018All_con['Flow Duration'].astype(np.int64)
     108 +
     109 +# IDS2018Train.to_csv(ids_path_pkl+"IDS2018Train.csv")
     110 +
     111 + IDS2018Train_len = IDS2018Train.shape[0]
     112 + IDS2018Test_len = IDS2018Test.shape[0]
     113 + print("Train size", IDS2018Train_len)
     114 + print("Test size", IDS2018Test_len)
     115 +
     116 + IDS2018All_is_y = IDS2018All_con["Label"].copy()
     117 + IDS2018All_is_x = IDS2018All_con.drop(["Label"], axis=1)
     118 +
     119 + IDS2018Train_is_y = IDS2018Train["Label"].copy()
     120 + IDS2018Train_is_x = IDS2018Train.drop(["Label"], axis=1)
     121 +
     122 + IDS2018Test_is_y = IDS2018Test["Label"].copy()
     123 + IDS2018Test_is_x = IDS2018Test.drop(["Label"], axis=1)
     124 +
     125 + class_mapping = {'Bot': 0, 'Benign': 1}
     126 + Y_All = IDS2018All_is_y.map(class_mapping)
     127 + Y_Train = IDS2018Train_is_y.map(class_mapping)
     128 + Y_Test = IDS2018Test_is_y.map(class_mapping)
     129 +# Y_Train = np.asarray(Y_T)
     130 +# Y_Test = np.asarray(Y_Te)
     131 +
     132 +
     133 +class preprocess_data:
     134 + col_names_onehot = ["Dst Port" ,"Protocol"]
     135 + IDS2018All_num = read_data.IDS2018All_is_x.drop(col_names_onehot, axis=1) #pd
     136 + IDS2018All_onehot_s = read_data.IDS2018All_is_x[col_names_onehot] #pd
     137 +
     138 + num_pipeline = Pipeline([('scaling', StandardScaler())])
     139 + cat_string_pipeline = Pipeline([('imputer', SimpleImputer(strategy = "constant", fill_value = "missing")), ('ordi', OrdinalEncoder()), ('onehots', OneHotEncoder(categories='auto'))])
     140 +
     141 + num_attribs = list(IDS2018All_num)
     142 + cat_s_attribs = list(IDS2018All_onehot_s)
     143 +
     144 + full_pipeline = ColumnTransformer([("num", num_pipeline, num_attribs), ("cats", cat_string_pipeline, cat_s_attribs)])
     145 +
     146 + IDS2018All_t = full_pipeline.fit_transform(read_data.IDS2018All_is_x)
     147 +
     148 + X_Train_i = IDS2018All_t[:read_data.IDS2018Train_len]
     149 + X_Test_i = IDS2018All_t[read_data.IDS2018Train_len:read_data.IDS2018Train_len + read_data.IDS2018Test_len]
     150 +
     151 + X_Train = np.expand_dims(X_Train_i, axis=1)
     152 + X_Test = np.expand_dims(X_Test_i, axis=1)
     153 +
     154 +# X_Train = sp.sparse.csr_matrix(X_Train_)
     155 +# X_Test = sp.sparse.csr_matrix(X_Test_)
     156 +
     157 + Y_Train = read_data.Y_Train
     158 + Y_Test = read_data.Y_Test
     159 +
     160 +# Y_Train_np = Y_Train.to_numpy()
     161 +# Y_Test_np = Y_Test.to_numpy()
     162 +
     163 +# Y_Train = np.expand_dims(Y_Train_, axis=1)
     164 +# Y_Test = np.expand_dims(Y_Test_, axis=1)
     165 +
     166 +
     167 +import tensorflow as tf
     168 +from tensorflow.keras.layers import Dense
     169 +from tensorflow.keras.layers import LSTM
     170 +from tensorflow.keras import optimizers
     171 +from tensorflow.keras import models
     172 +from tensorflow.keras import layers
     173 +from tensorflow.keras import wrappers
     174 +from tensorflow.keras import initializers
     175 +from tensorflow.keras import regularizers
     176 +from tensorflow.keras import losses
     177 +from scipy.stats import reciprocal
     178 +from sklearn.model_selection import RandomizedSearchCV
     179 +from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
     180 +from sklearn.metrics import classification_report
     181 +from sklearn.metrics import mean_squared_error
     182 +import time
     183 +
     184 +input_dim = preprocess_data.X_Train.shape[2]
     185 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     186 +print(input_dim)
     187 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     188 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     189 +#hidden_layers = 1
     190 +classes = 2
     191 +loop_back =1
     192 +hidden_encoder_dim = input_dim
     193 +hidden_decoder_dim = input_dim
     194 +
     195 +def build_model(learning_rate, hidden_layers, initiali):
     196 + model = models.Sequential([
     197 + layers.LSTM(hidden_layers, input_shape=(loop_back,input_dim)),
     198 +# layers.LSTM(hidden_layers, input_shape=(loop_back,input_dim), return_sequences=True),
     199 +# layers.BatchNormalization(),
     200 + # layers.Dense(1, activation ='softmax')])
     201 + layers.Dense(1)])
     202 + adamopt = optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, amsgrad=False)
     203 + # model.compile(loss="binary_crossentropy", optimizer=adamopt, metrics=['accuracy'])
     204 + model.compile(loss="binary_crossentropy", optimizer=adamopt)
     205 +# print(model.summary())
     206 + return model
     207 +
     208 +keras_reg = wrappers.scikit_learn.KerasClassifier(build_model)
     209 +
     210 +param_distribs = {"learning_rate": reciprocal(0.0001, 0.0005), "hidden_layers":[1,2,4], "initiali":['glorot_uniform', 'he_uniform'] }
     211 +
     212 +rnd_search_cv = RandomizedSearchCV(keras_reg, param_distribs, cv=5, scoring='accuracy', n_jobs=-1, error_score=1)
     213 +#rnd_search_cv = RandomizedSearchCV(keras_reg, param_distribs, cv=5, scoring='accuracy', error_score=1)
     214 +
     215 +X_Train = preprocess_data.X_Train
     216 +Y_Train = read_data.Y_Train
     217 +
     218 +X_Test = preprocess_data.X_Test
     219 +Y_Test = read_data.Y_Test
     220 +
     221 +batch_s = 2000
     222 +epoches = 250
     223 +ver = 2
     224 +
     225 +start_time = time.time()
     226 +rnd_search_cv.fit(X_Train, Y_Train, batch_size=batch_s, epochs=epoches, verbose=ver)
     227 +pred_time = time.time()
     228 +pred_test = rnd_search_cv.predict(X_Test)
     229 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!TEST PRED STARTED!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     230 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     231 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     232 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     233 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     234 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     235 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     236 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     237 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     238 +print(pred_test)
     239 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     240 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     241 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     242 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     243 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     244 +print(Y_Test)
     245 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     246 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     247 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     248 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     249 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     250 +print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
     251 +print("Estimator time:", pred_time - start_time)
     252 +print("Prediction time:", time.time() - pred_time)
     253 +print("Total time for fit and predict: %s seconds" % (time.time() - start_time))
     254 +print("Classclassification_report: ", classification_report(Y_Test, pred_test))
     255 +print("Best estimator: \n", rnd_search_cv.best_estimator_)
     256 +print("Best score: \n", rnd_search_cv.best_score_)
     257 +print("Best params: \n", rnd_search_cv.best_params_)
     258 +print("Refit time: \n", rnd_search_cv.refit_time_)
     259 +
     260 +from sklearn.metrics import confusion_matrix
     261 +cm = confusion_matrix(Y_Test, pred_test)
     262 +import itertools
     263 +classes = ['attack','normal']
     264 +plt.figure()
     265 +plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
     266 +plt.title('Confusion matrix')
     267 +plt.colorbar()
     268 +tick_marks = np.arange(len(classes))
     269 +plt.xticks(tick_marks, classes, rotation=45)
     270 +plt.yticks(tick_marks, classes)
     271 +print(cm)
     272 +thresh = cm.max() / 2.
     273 +for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
     274 + plt.text(j, i, cm[i, j].round(4),
     275 + horizontalalignment="center",
     276 + color="white" if cm[i, j] > thresh else "black")
     277 +
     278 +plt.tight_layout()
     279 +plt.ylabel('True label')
     280 +plt.xlabel('Predicted label')
     281 +
     282 +plt.savefig("cm_lstm.png")
     283 + 
Please wait...
Page is in error, reload to recover