diff --git a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py
index 8bd08d4f3a276a2566ccaaf7d6db06edd500a8db..2f11e1cfb5d57a50703cb3d44e2b05a7fb1a1b4f 100644
--- a/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py
+++ b/src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py
@@ -31,6 +31,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
         self.model = rt.InferenceSession(MODEL_FILE)
         self.input_name = self.model.get_inputs()[0].name
         self.label_name = self.model.get_outputs()[0].name
+        self.prob_name = self.model.get_outputs()[1].name
 
     def make_inference(self, request):
 
@@ -48,7 +49,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
             ]
         ])
 
-        predictions = self.model.run([self.label_name], {self.input_name: x_data.astype(np.float32)})[0]
+        predictions = self.model.run(
+            [self.prob_name], {self.input_name: x_data.astype(np.float32)})[0]
         # Output format
         output_message = {
            "confidence": None,
@@ -63,7 +65,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto
            "time_start": request.time_start,
            "time_end": request.time_end,
        }
-        if predictions[0] >= 0.5:
+        if predictions[0][1] >= 0.5:
            output_message["confidence"] = predictions[0][1]
            output_message["tag_name"] = "Crypto"
            output_message["tag"] = 1
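
For context, a minimal sketch of why the change swaps `self.label_name` for `self.prob_name`. It assumes a binary classifier exported with skl2onnx, which typically exposes a hard-label output first and a per-class probability output second; the model path `model.onnx` and the 10-feature input are placeholders, not this component's actual artifacts:

```python
# Sketch only: shows the difference between the two ONNX Runtime outputs that
# the diff above switches between. "model.onnx" and the feature vector size are
# placeholders; the real component builds its input from the request fields.
import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("model.onnx")

input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name  # e.g. "output_label": already-thresholded 0/1 decision
prob_name = sess.get_outputs()[1].name   # e.g. "output_probability": per-class scores

x = np.random.rand(1, 10).astype(np.float32)  # placeholder feature vector

labels = sess.run([label_name], {input_name: x})[0]
probs = sess.run([prob_name], {input_name: x})[0]

# labels[0] is a hard class label, so comparing it against 0.5 only reproduces
# the model's own decision and never yields a usable confidence. probs[0] holds
# the score per class (a dict keyed by class label under skl2onnx's default
# ZipMap, or a 2-element array without it), so probs[0][1] is the score of the
# positive "Crypto" class, which is what both the threshold check and the
# reported "confidence" field need.
print(labels[0], probs[0][1])
```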