From 50c71e7ad2d990f7ea1a9beee263370830499a95 Mon Sep 17 00:00:00 2001
From: ldemarcosm <l.demarcosm@alumnos.upm.es>
Date: Fri, 5 Nov 2021 15:08:33 +0100
Subject: [PATCH] Use ONNX model probability output for attack-detection threshold

---
 .../l3_centralizedattackdetectorServiceServicerImpl.py      | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py
index 8bd08d4f3..2f11e1cfb 100644
--- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py
+++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py
@@ -31,6 +31,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
         self.model = rt.InferenceSession(MODEL_FILE)
         self.input_name = self.model.get_inputs()[0].name
         self.label_name = self.model.get_outputs()[0].name
+        self.prob_name = self.model.get_outputs()[1].name
         
 
     def make_inference(self, request):
@@ -48,7 +49,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
                 ]
             ])
 
-        predictions = self.model.run([self.label_name], {self.input_name: x_data.astype(np.float32)})[0]
+        predictions = self.model.run(
+            [self.prob_name], {self.input_name: x_data.astype(np.float32)})[0]
         # Output format
         output_message = {
             "confidence": None,
@@ -63,7 +65,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
             "time_start": request.time_start,
             "time_end": request.time_end,
         }
-        if predictions[0] >= 0.5:
+        if predictions[0][1] >= 0.5:
             output_message["confidence"] = predictions[0][1]
             output_message["tag_name"] = "Crypto"
             output_message["tag"] = 1
-- 
GitLab