aaljabari committed on
Commit
7255762
·
verified ·
1 Parent(s): f0d8d6d

Add event argument extraction

Browse files
Files changed (1) hide show
  1. main.py +180 -0
main.py CHANGED
@@ -599,6 +599,186 @@ def predict_re(request: RERequest):
599
  except Exception as e:
600
  return {"error": str(e)}
601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
602
  # =========== Front End =============================
603
  from fastapi.staticfiles import StaticFiles
604
  from fastapi.responses import FileResponse
 
599
  except Exception as e:
600
  return {"error": str(e)}
601
 
602
+
603
+ # ============ Event Argument Extraction ==============
604
+ from transformers import pipeline
605
+
606
# Checkpoint loaded for both the model and the tokenizer of the event pipeline.
EVENT_MODEL_ID = "SinaLab/arabic-relation-extraction-model"
# Passed to the pipeline as max_length; inputs are truncated (truncation=True).
EVENT_MAX_LEN = 128

# Text-classification pipeline used to score "<sentence> [SEP] <event> <prompt>
# <argument>" strings. It is built once at import time and reused per request.
# NOTE(review): return_all_scores is kept deliberately; callers index the
# result as [0][0], so switching to top_k would need those call sites checked.
event_pipe = pipeline(
    task="sentiment-analysis",
    model=EVENT_MODEL_ID,
    tokenizer=EVENT_MODEL_ID,
    # First CUDA device when one is available, otherwise -1.
    device=0 if torch.cuda.is_available() else -1,
    return_all_scores=True,
    truncation=True,
    max_length=EVENT_MAX_LEN,
)
618
+
619
# Arabic connective phrase inserted between the event and the candidate
# argument when building the classifier input, keyed by argument category.
event_relation_prompt = {
    "location": "مكان حدوث",
    "agent": "أحد المتأثرين في",
    "happened at": "تاريخ حدوث",
}

# Entity tags that belong to each argument category. Lookup walks this
# mapping in insertion order, so the order of the keys is significant.
event_categories = {
    "agent": ["PERS", "NORP", "OCC", "ORG"],
    "location": ["LOC", "FAC", "GPE"],
    "happened at": ["DATE", "TIME"],
}

# Relation name reported in the API response for each argument category.
event_relation_name_map = {
    "agent": "hasAgent",
    "location": "hasLocation",
    "happened at": "hasDate",
}
636
+
637
+
638
def get_entity_category(entity_type, categories):
    """Return the first category whose tag collection contains ``entity_type``.

    Args:
        entity_type: Entity tag to look up.
        categories: Mapping of category name to a collection of entity tags.
            It is scanned in iteration order.

    Returns:
        The key of the first matching category, or ``None`` when no category
        lists ``entity_type``.
    """
    matching = (
        name for name, members in categories.items() if entity_type in members
    )
    return next(matching, None)
643
+
644
+
645
def get_positive_score(predicted_relation):
    """Return the score of the positive class from a classifier prediction.

    The prediction may be given in either of two shapes:

    * nested, one list of per-label dicts per input::

          [[{"label": "LABEL_0", "score": 0.12},
            {"label": "LABEL_1", "score": 0.88}]]

    * flat, just the per-label dicts for a single input::

          [{"label": "LABEL_0", "score": 0.12},
           {"label": "LABEL_1", "score": 0.88}]

    The positive class is the first entry labelled ``LABEL_1``, ``relation``,
    ``RELATION`` or ``positive``. If no entry carries one of those labels,
    the score of the first entry is returned instead.

    Args:
        predicted_relation: Prediction in one of the shapes above. Only the
            first input of a nested prediction is inspected.

    Returns:
        The selected ``"score"`` value.

    Raises:
        IndexError: If the prediction contains no per-label entries.
    """
    positive_labels = ("LABEL_1", "relation", "RELATION", "positive")

    if not predicted_relation:
        raise IndexError("predicted_relation is empty; there is no score to read")

    first = predicted_relation[0]
    # A dict in first position means the caller passed the flat shape.
    scores = predicted_relation if isinstance(first, dict) else first
    if not scores:
        raise IndexError("predicted_relation holds no per-label scores")

    for item in scores:
        if item["label"] in positive_labels:
            return item["score"]

    # No recognised positive label: fall back to the first entry.
    return scores[0]["score"]
671
+
672
+
673
def event_argument_extractor(sentence, *, min_confidence=0.0):
    """Score every (event, argument) entity pair found in ``sentence``.

    Entities come from ``entities_and_types(sentence)``, which is iterated as
    a mapping of entity text to entity tag. Entities tagged ``"EVENT"`` are
    subjects; every other entity whose tag maps to a category with a prompt
    is a candidate argument. Each pair is turned into a prompt string and
    scored with ``event_pipe``.

    Args:
        sentence: A single sentence of input text.
        min_confidence: A pair is reported only when its score is strictly
            greater than this value. The default of ``0.0`` keeps the
            original behaviour.

    Returns:
        A list of dicts with keys ``"Subject"``, ``"Relation"``, ``"Object"``
        and ``"Confidence"``, ordered by event and then by argument. The list
        is empty when the sentence has no events or no usable arguments.
    """
    entities = entities_and_types(sentence)

    event_entities = [
        (name, tag) for name, tag in entities.items() if tag == "EVENT"
    ]
    if not event_entities:
        return []

    # The category depends only on the argument, so resolve it once per
    # argument instead of once per (event, argument) pair.
    argument_candidates = []
    for name, tag in entities.items():
        if tag == "EVENT":
            continue
        category = get_entity_category(tag, event_categories)
        # category is None for tags outside every category; those are skipped.
        if category in event_relation_prompt:
            argument_candidates.append((name, tag, category))

    output_list = []

    for event_entity, event_type in event_entities:
        for arg_name, arg_type, category in argument_candidates:
            relation_sentence = (
                f"[CLS] {sentence} [SEP] "
                f"{event_entity} {event_relation_prompt[category]} {arg_name}"
            )

            predicted_relation = event_pipe(relation_sentence)
            # Score of the first label entry of the first (only) input.
            # NOTE(review): assumes the first label is the positive class;
            # confirm against the model's label order.
            score = predicted_relation[0][0]["score"]

            if score > min_confidence:
                output_list.append({
                    "Subject": {
                        "Type": event_type,
                        "Label": event_entity,
                    },
                    "Relation": event_relation_name_map[category],
                    "Object": {
                        "Type": arg_type,
                        "Label": arg_name,
                    },
                    "Confidence": float(round(score, 4)),
                })

    return output_list
720
+
721
+
722
class EAERequest(BaseModel):
    """Request body for the ``/predict_eae`` endpoint."""

    # Raw input text; the endpoint strips it and splits it into sentences.
    text: str
724
+
725
+
726
def _eae_response(resp, status_text, status_code, http_status):
    """Build the JSON envelope shared by every ``/predict_eae`` outcome."""
    return JSONResponse(
        content={
            "resp": resp,
            "statusText": status_text,
            "statusCode": status_code,
        },
        media_type="application/json",
        status_code=http_status,
    )


@app.post("/predict_eae")
def predict_eae(request: EAERequest):
    """Extract event arguments from the request text, sentence by sentence.

    The text is stripped and split with ``sentence_tokenizer`` (only the
    ``new_line`` option is enabled). Each non-blank sentence is passed to
    ``event_argument_extractor`` and the results are concatenated.

    Args:
        request: Request body carrying the input ``text``.

    Returns:
        A ``JSONResponse`` with keys ``resp``, ``statusText`` and
        ``statusCode``:

        * empty or whitespace-only text: HTTP 200, ``statusCode`` 1,
          ``statusText`` ``"EMPTY_INPUT"``, empty ``resp``;
        * success: HTTP 200, ``statusCode`` 0, ``statusText`` ``"OK"``;
        * any exception: HTTP 500, ``statusCode`` 500, ``statusText`` set to
          the exception message, empty ``resp``.
    """
    import logging  # local import keeps this block self-contained

    try:
        text = request.text.strip()

        if not text:
            return _eae_response([], "EMPTY_INPUT", 1, 200)

        sentences = sentence_tokenizer(
            text,
            dot=False,
            new_line=True,
            question_mark=False,
            exclamation_mark=False,
        )

        results = []
        for sentence in sentences:
            sentence = sentence.strip()
            if sentence:
                results.extend(event_argument_extractor(sentence))

        return _eae_response(results, "OK", 0, 200)

    except Exception as e:
        # Top-level boundary: record the traceback server-side so the failure
        # is diagnosable, then report it in the same envelope as before.
        logging.getLogger(__name__).exception("predict_eae failed")
        return _eae_response([], str(e), 500, 500)
780
+
781
+
782
  # =========== Front End =============================
783
  from fastapi.staticfiles import StaticFiles
784
  from fastapi.responses import FileResponse