import torch
import torch.nn as nn
# from transformers.modeling_distilbert import DistilBertPreTrainedModel, DistilBertModel, DistilBertConfig
from transformers.models.distilbert.modeling_distilbert import DistilBertPreTrainedModel, DistilBertModel, DistilBertConfig
from torchcrf import CRF

from .module import IntentClassifier, SlotClassifier
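
# NOTE (illustration, not part of the original file): `IntentClassifier` and
# `SlotClassifier` come from the sibling `module` file, which is not shown on
# this page. A typical implementation of such heads (an assumption about the
# actual module, not its verbatim contents) is dropout followed by a single
# linear projection, e.g.:
#
#     class IntentClassifier(nn.Module):
#         def __init__(self, input_dim, num_intent_labels, dropout_rate=0.0):
#             super().__init__()
#             self.dropout = nn.Dropout(dropout_rate)
#             self.linear = nn.Linear(input_dim, num_intent_labels)
#
#         def forward(self, x):
#             return self.linear(self.dropout(x))
#
# SlotClassifier would follow the same pattern, projecting each token's hidden
# state to num_slot_labels.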


class JointDistilBERT(DistilBertPreTrainedModel):
    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        super(JointDistilBERT, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.distilbert = DistilBertModel(config=config)  # Load pretrained DistilBERT

        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)

        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, intent_label_ids, slot_labels_ids):
        outputs = self.distilbert(input_ids, attention_mask=attention_mask)  # last-layer hidden-state, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = sequence_output[:, 0]  # [CLS] token (DistilBERT has no pooler, so use the first token's hidden state)
        intent_logits = self.intent_classifier(pooled_output)
        slot_logits = self.slot_classifier(sequence_output)

        total_loss = 0
        # 1. Intent Softmax
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1))
            total_loss += intent_loss

        # 2. Slot Softmax
        if slot_labels_ids is not None:
            if self.args.use_crf:
                slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction='mean')
                slot_loss = -1 * slot_loss  # negative log-likelihood
            else:
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                # Only keep active parts of the loss
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                    active_labels = slot_labels_ids.view(-1)[active_loss]
                    slot_loss = slot_loss_fct(active_logits, active_labels)
                else:
                    slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
            total_loss += self.args.slot_loss_coef * slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[1:]  # add hidden states and attention if they are here

        outputs = (total_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)  # logits is a tuple of intent and slot logits
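

# --- Minimal usage sketch (illustration only, not part of the original file). ---
# The `args` namespace, label lists, checkpoint-free initialisation and tensor
# shapes below are assumptions made for the example; the real training script
# supplies its own dropout_rate, use_crf, ignore_index and slot_loss_coef.
# Because of the relative import above, run this as a module inside its package
# (e.g. `python -m <package>.<this_module>`).
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(dropout_rate=0.1, use_crf=False, ignore_index=0, slot_loss_coef=1.0)
    intent_label_lst = ["greet", "book_flight"]        # placeholder intent labels
    slot_label_lst = ["PAD", "O", "B-city", "I-city"]  # placeholder slot labels

    # Randomly initialised model for a quick shape/loss check; in practice the
    # weights would come from JointDistilBERT.from_pretrained(<checkpoint>,
    # config=config, args=args, intent_label_lst=..., slot_label_lst=...).
    config = DistilBertConfig()
    model = JointDistilBERT(config, args, intent_label_lst, slot_label_lst)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    intent_label_ids = torch.tensor([0, 1])
    slot_labels_ids = torch.ones((2, 16), dtype=torch.long)  # all "O" slot labels

    outputs = model(input_ids, attention_mask, intent_label_ids, slot_labels_ids)
    total_loss, (intent_logits, slot_logits) = outputs[0], outputs[1]
    print(total_loss.item(), intent_logits.shape, slot_logits.shape)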