RuntimeError: mat1 dim 1 must match mat2 dim 0 in PyTorch

Chris Staff asked 1 month ago

I am defining the following MLP with PyTorch Lightning:

class MLP(pl.LightningModule):
  
  def __init__(self):
    super().__init__()
    self.layers = nn.Sequential(
      nn.Linear(32 * 32, 64),
      nn.ReLU(),
      nn.Linear(64, 32),
      nn.ReLU(),
      nn.Linear(32, 10)
    )
    self.ce = nn.CrossEntropyLoss()

I want to train it with the CIFAR-10 dataset:

if __name__ == '__main__':
  dataset = CIFAR10(os.getcwd(), download=True, transform=transforms.ToTensor())
  pl.seed_everything(42)
  mlp = MLP()
  trainer = pl.Trainer(auto_scale_batch_size='power', gpus=1, deterministic=True, max_epochs=5)
  trainer.fit(mlp, DataLoader(dataset))

I am however getting this error:

  File "lightning.py", line 28, in training_step
    y_hat = self.layers(x)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\container.py", line 117, in forward
    input = module(input)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0

What is the problem?

1 Answer
Best Answer
Chris Staff answered 1 month ago

CIFAR-10 is a color dataset, so every 32 x 32 image has 3 channels and flattens to 32 * 32 * 3 = 3072 values. The first layer of your network, however, expects only 32 * 32 = 1024 input features:

      nn.Linear(32 * 32, 64),
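
To see where the message "mat1 dim 1 must match mat2 dim 0" comes from, the shape arithmetic can be sketched in plain Python. This is only an illustration: it assumes each image is flattened to a single vector before it reaches `self.layers` (the `training_step` is not shown in the question) and that the batch size is 1, the `DataLoader` default. The variable names are made up for the example.

```python
from math import prod

# One CIFAR-10 image after ToTensor(): (channels, height, width)
image_shape = (3, 32, 32)
features_per_image = prod(image_shape)  # 3 * 32 * 32 = 3072

# Assumption: batch size 1 and one flattened vector per image
batch_size = 1
mat1_shape = (batch_size, features_per_image)

# nn.Linear(32 * 32, 64) holds a weight of shape (64, 1024); the traceback
# multiplies the input by weight.t(), which has shape (1024, 64)
mat2_shape = (32 * 32, 64)

# Matrix multiplication needs mat1's dim 1 to equal mat2's dim 0
shapes_match = mat1_shape[1] == mat2_shape[0]
print(mat1_shape, mat2_shape, shapes_match)  # (1, 3072) (1024, 64) False
```

The 3072 columns of the input do not line up with the 1024 rows of the transposed weight, which is exactly what the error reports.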

If you include the channels in the input size of the first layer, the shapes match:

      nn.Linear(32 * 32 * 3, 64),
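
A quick way to check the corrected layer sizes is to push a dummy batch through the same stack outside of Lightning. The sketch below is not the code from the question: the `nn.Flatten()` layer is an addition made here so the snippet is self-contained, and the random tensor merely stands in for a batch of CIFAR-10 images.

```python
import torch
from torch import nn

# Same layers as in the question, with 32 * 32 * 3 input features.
# nn.Flatten() is an assumption of this sketch: the question does not
# show how the images are flattened before they reach self.layers.
layers = nn.Sequential(
    nn.Flatten(),
    nn.Linear(32 * 32 * 3, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 10),
)

# Random stand-in for 4 CIFAR-10 images: (batch, channels, height, width)
x = torch.rand(4, 3, 32, 32)
y_hat = layers(x)

out_shape = tuple(y_hat.shape)
print(out_shape)  # (4, 10): one score per class for each image
```

If the first layer is changed back to `nn.Linear(32 * 32, 64)`, the same forward pass raises a shape mismatch error again.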
