PyTorch/MNIST: mat1 and mat2 shapes cannot be multiplied (10×784 and 3072×64)

Chris Staff asked 6 months ago

I am creating a PyTorch model that is training on the MNIST dataset, but I am getting the following error.

Traceback (most recent call last):
  File "classic.py", line 66, in <module>
    outputs = mlp(inputs)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "classic.py", line 26, in forward
    return self.layers(x)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\container.py", line 117, in forward
    input = module(input)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "C:\Users\chris\Anaconda3\envs\pytorch_gpu\lib\site-packages\torch\nn\functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x784 and 3072x64)
9920512it [00:06, 1460884.87it/s]
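The last frame of the traceback shows where things go wrong: `F.linear` calls `torch.addmm(bias, input, weight.t())`, a matrix product that requires the number of columns of `input` to equal the number of rows of `weight.t()`. The stdlib-only sketch below imitates that shape check so the numbers in the message can be traced by hand. The helpers `flattened_features` and `linear_output_shape` are hypothetical names written for illustration; they only mimic PyTorch's error text and are not part of any library.

```python
import math


def flattened_features(sample_shape):
    """Features per sample after nn.Flatten(): product of all non-batch dims."""
    return math.prod(sample_shape)


def linear_output_shape(input_shape, in_features, out_features):
    """Imitate the shape check behind F.linear for a 2-D (batch, features) input.

    An nn.Linear weight has shape (out_features, in_features), so weight.t()
    is (in_features, out_features); the inner dimensions must match.
    """
    batch, features = input_shape
    if features != in_features:
        # Hypothetical message that copies the wording of the PyTorch error.
        raise ValueError(
            f"mat1 and mat2 shapes cannot be multiplied "
            f"({batch}x{features} and {in_features}x{out_features})"
        )
    return (batch, out_features)


batch = 10
mnist_features = flattened_features((1, 28, 28))  # 784
cifar_features = flattened_features((3, 32, 32))  # 3072

try:
    # The first Linear in the question expects 3072 inputs but receives 784.
    linear_output_shape((batch, mnist_features), cifar_features, 64)
except ValueError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (10x784 and 3072x64)

# With in_features matching the flattened MNIST size, the product is valid.
print(linear_output_shape((batch, mnist_features), 784, 64))  # (10, 64)
```

The first number in the message is the batch size (10), the second is the flattened feature count of each input, and the last pair is the transposed weight of the first `nn.Linear`.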

This is my network.

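from torch import nn

class MLP(nn.Module):
  '''
    Multilayer Perceptron.
  '''
  def __init__(self):
    super().__init__()
    self.layers = nn.Sequential(
      nn.Flatten(),
      nn.Linear(32 * 32 * 3, 64),  # expects 32 * 32 * 3 = 3072 input features
      nn.ReLU(),
      nn.Linear(64, 32),
      nn.ReLU(),
      nn.Linear(32, 10)
    )

  def forward(self, x):
    '''Forward pass.'''
    return self.layers(x)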

What is going on here?

1 Answer
Best Answer
Chris Staff answered 6 months ago

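The issue is the first Linear layer, whose input size is wrong. MNIST images are 28×28 pixels with a single (grayscale) channel, so after nn.Flatten() each sample has 28 * 28 * 1 = 784 features. Your first layer was written for 32×32 RGB images (the shape used by datasets such as CIFAR-10), which flatten to 32 * 32 * 3 = 3072 features. That is exactly what the error reports: a batch of 10 inputs with 784 features each (10x784) cannot be multiplied by a weight that expects 3072 inputs (3072x64). Change the first layer into this and it will work: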

      nn.Linear(28 * 28 * 1, 64),
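A minimal sketch of the corrected model, assuming the inputs come from torchvision's `ToTensor()` transform, which yields MNIST batches shaped `(N, 1, 28, 28)`. The `in_features` constructor argument is a hypothetical addition (it is not in the original code) so the same class could be reused for other image sizes. Running a dummy batch of 10 through the model confirms the shapes line up before any training starts.

```python
import torch
from torch import nn


class MLP(nn.Module):
  '''
    Multilayer Perceptron for flattened MNIST images.
  '''
  def __init__(self, in_features=28 * 28 * 1):
    # in_features is a hypothetical parameter; 784 matches flattened MNIST.
    super().__init__()
    self.layers = nn.Sequential(
      nn.Flatten(),                # (N, 1, 28, 28) -> (N, 784)
      nn.Linear(in_features, 64),  # must equal the flattened feature count
      nn.ReLU(),
      nn.Linear(64, 32),
      nn.ReLU(),
      nn.Linear(32, 10)            # 10 output scores, one per digit class
    )

  def forward(self, x):
    return self.layers(x)


# Sanity check with an all-zeros batch shaped like MNIST tensors: (N, C, H, W).
mlp = MLP()
dummy = torch.zeros(10, 1, 28, 28)
out = mlp(dummy)
print(out.shape)  # torch.Size([10, 10])
```

Passing `in_features=32 * 32 * 3` would restore the original layout for three-channel 32×32 images.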
