33 lines
1013 B
Python
33 lines
1013 B
Python
from transformers.models.bert import (
|
|
BertTokenizer,
|
|
BertForSequenceClassification
|
|
)
|
|
import torch
|
|
from typing import List
|
|
import pandas as pd
|
|
|
|
|
|
class BERTImportanceRanker:
    """Score texts with a two-label BERT sequence-classification model.

    The score reported for each text is the softmax probability that the
    model assigns to label index 1.

    NOTE(review): ``num_labels=2`` is forced when loading the model; confirm
    that ``model_name`` points at a checkpoint fine-tuned with a matching
    two-label head, otherwise the scores are presumably not meaningful.
    """

    def __init__(self, model_name: str, max_length: int = 64):
        """Load the tokenizer and classifier and switch to eval mode.

        Args:
            model_name: Identifier or path handed to ``from_pretrained`` for
                both the tokenizer and the model.
            max_length: Token limit used when truncating inputs in
                ``score_texts``. Defaults to 64.
        """
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=2
        )
        # Inference only: put the model in evaluation mode.
        self.model.eval()
        self.max_length = max_length

    def score_texts(self, texts: List[str]) -> List[float]:
        """Return one score per input text.

        All texts are tokenized together (padded, truncated to
        ``self.max_length`` tokens) and sent through the model in a single
        forward pass with gradients disabled.

        Args:
            texts: The texts to score.

        Returns:
            A list with one float per text: the softmax probability of label
            index 1. An empty list or tuple of texts yields ``[]``.
        """
        # An empty batch is short-circuited instead of being passed to the
        # tokenizer and model. Only empty list/tuple inputs are intercepted,
        # so any other input is still forwarded unchanged.
        if isinstance(texts, (list, tuple)) and not texts:
            return []

        inputs = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.max_length,
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # Column 1 holds the probability of label index 1 for every row.
        return probs[:, 1].tolist()

    def apply_to_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """Score the ``"text"`` column and store the result in ``df``.

        The scores are written to an ``"importance_score"`` column on the
        DataFrame that was passed in (it is modified in place), and that same
        DataFrame is returned.

        Args:
            df: DataFrame with a ``"text"`` column.

        Returns:
            The same ``df`` object, with ``"importance_score"`` set.
        """
        scores = self.score_texts(df["text"].tolist())
        df["importance_score"] = scores
        return df
|