ValueError: à l'Aide d'une cible de taille (de la torche.Taille([2, 1])) qui est différente de la taille de saisie (torche.Taille([16, 1])) est obsolète

0

La question

Je suis en train de construire un modèle pour la Quora questions paire dataset où la sortie est binaire 1 ou 0, mais j'ai cette erreur. Je sais que la sortie de la forme de mon modèle est différent de l'entrée de la forme, mais je ne sais pas comment le résoudre. La taille des lots est fixé à 16

    class Bert_model (nn.Module):
      def __init__(self) :
        super(Bert_model,self).__init__()
        self.bert =  BertModel.from_pretrained('bert-base-uncased', return_dict=False)
        self.drop_layer = nn.Dropout(.25)
        self.output = nn.Linear(self.bert.config.hidden_size,1)
    
      def forward(self,input_ids,attention_mask):
        _,o2 = self.bert (input_ids =input_ids , attention_mask = attention_mask )
        o2 = self.drop_layer(o2)
        return self.output(o2)

    model = Bert_model()
    
    loss_fn = nn.BCELoss().to(device)

    def train_epoch(
      model, 
      data_loader, 
      loss_fn, 
      optimizer, 
      device, 
      n_examples
    ):
      model = model.train()
    
      losses = []
      correct_predictions = 0
      
      for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["target"].to(device)
    
        input_ids = input_ids.view(BATCH_SIZE,-1)
        attention_mask = attention_mask.view(BATCH_SIZE,-1)
    
        outputs = model(
          input_ids=input_ids,
          attention_mask=attention_mask
        )
    
        _, preds = torch.max(outputs, dim=1)
    
        targets = targets.unsqueeze(-1)
        loss = loss_fn(F.softmax(outputs,dim=1), targets)
    
        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())
    
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
    
      return correct_predictions.double() / n_examples, np.mean(losses)

Le message d'erreur:

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in
binary_cross_entropy(input, target, weight, size_average, reduce,
reduction)    2913         weight = weight.expand(new_size)    2914 
-> 2915     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)    2916     2917  ValueError: Using a target
size (torch.Size([2, 1])) that is different to the input size
(torch.Size([16, 1])) is deprecated
deep-learning pytorch
2021-11-21 11:25:25
1

La meilleure réponse

0

À partir de la trace de la pile, l'erreur se produit dans le BCELoss calculs, cela est dû au fait que la outputs.shape est (16, 1)tandis que targets.shape est (2, 1).

Je vois un problème majeur dans votre code: BCELoss est utilisé pour comparer les distributions de probabilité (vérifier les docs), mais votre sortie du modèle est de la forme (n, 1)n est la taille du lot (dans votre cas, 16). En fait, dans l'instruction de retour de forward vous passez o2 pour un linéaire de la couche dont la sortie est de la forme 1.

La Question Paires Dataset est destiné pour la classification binaire des tâches, de sorte que vous besoin de convertir votre sortie dans une distribution de probabilité, par exemple, à l'aide d'un Sigmoid ou réglage linéaire de la couche de la taille de la sortie 2, puis à l'aide de la softmax.

2021-11-21 15:50:29

En outre, vous pouvez passer BCELoss avec CrossEntropyLossqui est destiné pour les problèmes de classification.
aretor

- je changer la perte de fonction (BCEWithLogitsLoss), qui s'applique sigmoïde à la sortie , puis j'ai enlevé le softmax . le problème existe toujours, mais maintenant, parce que la taille de la cible est (10,1) et les différents à partir d'une entrée (16,1)
BuzzedHub

Il est difficile de dire que l'erreur à partir de votre code. Étant donné que 16 est la bonne taille de lot, vérifiez bien quand votre cible les changements de taille de 16 à 10. Veuillez éviter de modifier le corps de votre question, les réponses ne seront pas plus aucun sens.
aretor

Dans d'autres langues

Cette page est dans d'autres langues

Русский
..................................................................................................................
Italiano
..................................................................................................................
Polski
..................................................................................................................
Română
..................................................................................................................
한국어
..................................................................................................................
हिन्दी
..................................................................................................................
Türk
..................................................................................................................
Česk
..................................................................................................................
Português
..................................................................................................................
ไทย
..................................................................................................................
中文
..................................................................................................................
Español
..................................................................................................................
Slovenský
..................................................................................................................