From 50c71e7ad2d990f7ea1a9beee263370830499a95 Mon Sep 17 00:00:00 2001 From: ldemarcosm <l.demarcosm@alumnos.upm.es> Date: Fri, 5 Nov 2021 15:08:33 +0100 Subject: [PATCH] Fixed minor issue with ml model --- .../l3_centralizedattackdetectorServiceServicerImpl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py index 8bd08d4f3..2f11e1cfb 100644 --- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py +++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py @@ -31,6 +31,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto self.model = rt.InferenceSession(MODEL_FILE) self.input_name = self.model.get_inputs()[0].name self.label_name = self.model.get_outputs()[0].name + self.prob_name = self.model.get_outputs()[1].name def make_inference(self, request): @@ -48,7 +49,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto ] ]) - predictions = self.model.run([self.label_name], {self.input_name: x_data.astype(np.float32)})[0] + predictions = self.model.run( + [self.prob_name], {self.input_name: x_data.astype(np.float32)})[0] # Output format output_message = { "confidence": None, @@ -63,7 +65,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto "time_start": request.time_start, "time_end": request.time_end, } - if predictions[0] >= 0.5: + if predictions[0][1] >= 0.5: output_message["confidence"] = predictions[0][1] output_message["tag_name"] = "Crypto" output_message["tag"] = 1 -- GitLab