Loading src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py +4 −2 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto self.model = rt.InferenceSession(MODEL_FILE) self.input_name = self.model.get_inputs()[0].name self.label_name = self.model.get_outputs()[0].name self.prob_name = self.model.get_outputs()[1].name def make_inference(self, request): Loading @@ -48,7 +49,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto ] ]) predictions = self.model.run([self.label_name], {self.input_name: x_data.astype(np.float32)})[0] predictions = self.model.run( [self.prob_name], {self.input_name: x_data.astype(np.float32)})[0] # Output format output_message = { "confidence": None, Loading @@ -63,7 +65,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto "time_start": request.time_start, "time_end": request.time_end, } if predictions[0] >= 0.5: if predictions[0][1] >= 0.5: output_message["confidence"] = predictions[0][1] output_message["tag_name"] = "Crypto" output_message["tag"] = 1 Loading Loading
src/l3_centralizedattackdetector/service/l3_centralizedattackdetectorServiceServicerImpl.py +4 −2 Original line number Diff line number Diff line Loading @@ -31,6 +31,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto self.model = rt.InferenceSession(MODEL_FILE) self.input_name = self.model.get_inputs()[0].name self.label_name = self.model.get_outputs()[0].name self.prob_name = self.model.get_outputs()[1].name def make_inference(self, request): Loading @@ -48,7 +49,8 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto ] ]) predictions = self.model.run([self.label_name], {self.input_name: x_data.astype(np.float32)})[0] predictions = self.model.run( [self.prob_name], {self.input_name: x_data.astype(np.float32)})[0] # Output format output_message = { "confidence": None, Loading @@ -63,7 +65,7 @@ class l3_centralizedattackdetectorServiceServicerImpl(L3Centralizedattackdetecto "time_start": request.time_start, "time_end": request.time_end, } if predictions[0] >= 0.5: if predictions[0][1] >= 0.5: output_message["confidence"] = predictions[0][1] output_message["tag_name"] = "Crypto" output_message["tag"] = 1 Loading