Source code for meerqat.models.rr
# -*- coding: utf-8 -*-
from torch import nn
from transformers import BertModel
from transformers.models.bert import BertPreTrainedModel
from .mm import ECAEncoder, FlamantModel
from .outputs import ReRankerOutput
class BertReRanker(BertPreTrainedModel):
    """
    As described in [1]_. Almost like BertForSequenceClassification,
    but without dropout, and with pooling from the [CLS] token.

    References
    ----------
    .. [1] Zhiguo Wang, Patrick Ng, Xiaofei Ma, Ramesh Nallapati, and Bing Xiang.
       2019. Multi-passage BERT: A Globally Normalized BERT Model for Open-domain
       Question Answering. In Proceedings of the 2019 Conference on Empirical
       Methods in Natural Language Processing and the 9th International Joint
       Conference on Natural Language Processing (EMNLP-IJCNLP), pages 5878–5882,
       Hong Kong, China. Association for Computational Linguistics.
    """
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.bert = BertModel(config, add_pooling_layer=False)
        self.classifier = nn.Linear(config.hidden_size, 1)
        # Initialize weights and apply final processing
        self.post_init()
    def forward(self, *args, **kwargs):
        outputs = self.bert(*args, **kwargs)
        # Pool from the [CLS] token (first position of the last hidden state)
        pooled_output = outputs.last_hidden_state[:, 0]
        logits = self.classifier(pooled_output)
        return ReRankerOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )
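
# A minimal usage sketch (an illustration, not part of the original module):
# score question-passage pairs with BertReRanker and rank passages by logit.
# The "bert-base-uncased" checkpoint and the example strings are assumptions,
# and the classifier head stays randomly initialized until fine-tuned.
def _example_bert_rerank():
    import torch
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    model = BertReRanker.from_pretrained("bert-base-uncased")
    question = "Who wrote Hamlet?"
    passages = [
        "Hamlet is a tragedy written by William Shakespeare.",
        "Macbeth was first performed in 1606.",
    ]
    # Encode each (question, passage) pair as [CLS] question [SEP] passage [SEP]
    inputs = tokenizer(
        [question] * len(passages), passages,
        padding=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits.squeeze(-1)  # shape: (n_passages,)
    return logits.argsort(descending=True)  # passage indices, best first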
class ECAReRanker(ECAEncoder):
    """Like BertReRanker, but with an ECA backbone instead of BERT."""
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.classifier = nn.Linear(config.hidden_size, 1)
        # Initialize weights and apply final processing
        self.post_init()
    def forward(self, *args, return_dict=True, **kwargs):
        outputs = super().forward(*args, return_dict=return_dict, **kwargs)
        logits = self.classifier(outputs.pooler_output)
        return ReRankerOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )
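
# Design note: BertReRanker pools the raw [CLS] hidden state, whereas
# ECAReRanker (and FlamantReRanker below) classify from the backbone's
# pooler_output, which in BERT-style models is the [CLS] state passed
# through a learned dense layer with a tanh activation. A sketch of the
# two poolings, assuming a standard BERT-style pooler:
def _example_poolings(last_hidden_state, pooler_dense):
    import torch
    cls_state = last_hidden_state[:, 0]                   # raw [CLS] pooling
    pooler_output = torch.tanh(pooler_dense(cls_state))   # BERT-style pooler
    return cls_state, pooler_output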
class FlamantReRanker(FlamantModel):
    """Like BertReRanker, but with a Flamant backbone instead of BERT."""
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        self.classifier = nn.Linear(config.hidden_size, 1)
        # Initialize weights and apply final processing
        self.post_init()
    def forward(self, *args, return_dict=True, **kwargs):
        outputs = super().forward(*args, return_dict=return_dict, **kwargs)
        logits = self.classifier(outputs.pooler_output)
        return ReRankerOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )
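
# Training note: the Multi-passage BERT of reference [1] above is
# "globally normalized": the logits of all passages retrieved for the
# same question share a single softmax. A minimal sketch of such a
# listwise loss, assuming `logits` has been reshaped to
# (n_questions, n_passages) and `target` gives the index of the relevant
# passage per question (both shapes are assumptions for illustration):
def _example_global_norm_loss(logits, target):
    import torch.nn.functional as F
    # cross_entropy = log-softmax over the passage axis + NLL on the
    # relevant passage, i.e. normalization across all passages at once
    return F.cross_entropy(logits, target)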