recap_gen/scoring/bert_ranker.py

from typing import List

import pandas as pd
import torch
from transformers import BertForSequenceClassification, BertTokenizer


class BERTImportanceRanker:
    """Scores texts by importance with a binary BERT classifier."""

    def __init__(self, model_name: str):
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )
        self.model.eval()  # disable dropout for deterministic inference

    def score_texts(self, texts: List[str]) -> List[float]:
        # Tokenize as one padded batch, truncating inputs to 64 tokens.
        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=64,
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Softmax over the two classes; the probability of the positive
        # ("important") class is the importance score.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        return probs[:, 1].tolist()

    def apply_to_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        # Note: writes the score column into df in place and returns it.
        scores = self.score_texts(df["text"].tolist())
        df["importance_score"] = scores
        return df
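

# A minimal usage sketch. "bert-base-uncased" is a placeholder model name and
# the sample texts are illustrative; in practice a checkpoint fine-tuned for
# importance classification would be loaded instead.
if __name__ == "__main__":
    ranker = BERTImportanceRanker("bert-base-uncased")
    sample = pd.DataFrame(
        {"text": ["Revenue grew 12% year over year.", "Lunch is at noon."]}
    )
    print(ranker.apply_to_dataframe(sample))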