TypeError: __init__() got an unexpected keyword argument 'train' in PyTorch

Chris Staff asked 6 months ago

I am trying to load the predefined train/test split within the PyTorch CIFAR10 dataset as follows:

  import os
  import torch
  from torchvision import transforms
  from torchvision.datasets import CIFAR10
  # Prepare CIFAR-10 dataset
  dataset = CIFAR10(os.getcwd(), download=True, transform=transforms.ToTensor())
  trainloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1, train=True)
  trainloader_test = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1, train=False)

However, when running the model, I’m getting this error:

  File "classic.py", line 36, in
    trainloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1, train=True)
  TypeError: __init__() got an unexpected keyword argument 'train'

Why does this occur?

Best Answer
Chris Staff answered 6 months ago

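The train boolean must be passed to the CIFAR10 dataset, not to the DataLoader 🙂 DataLoader only batches, shuffles and loads whatever dataset you give it; its __init__ has no train parameter, which is why Python raises the TypeError. Create one dataset object per split (train=True for the 50,000 training images, train=False for the 10,000 test images) and wrap each one in its own DataLoader: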

  import os
  import torch
  from torchvision import transforms
  from torchvision.datasets import CIFAR10

  # Prepare CIFAR-10 dataset: the train flag selects the split on the dataset itself
  dataset = CIFAR10(os.getcwd(), download=True, transform=transforms.ToTensor(), train=True)        # training split
  dataset_test = CIFAR10(os.getcwd(), download=True, transform=transforms.ToTensor(), train=False)  # test split
  trainloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=1)
  trainloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=10, shuffle=True, num_workers=1)
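The error itself is ordinary Python behaviour: any class whose __init__ does not declare a keyword (and has no **kwargs) rejects it with exactly this TypeError. The sketch below reproduces that with two hypothetical stand-in classes, ToyDataset and ToyLoader, which are NOT the real torchvision/PyTorch classes and use made-up sizes. It only illustrates where the split choice lives: on the dataset, while the loader just batches.

```python
# Hypothetical stand-ins (NOT the real torchvision/PyTorch classes) that mirror
# where the train flag lives: on the dataset, not on the loader.


class ToyDataset:
    """Picks a split at construction time, similar in spirit to CIFAR10(train=...)."""

    def __init__(self, train=True):
        # Made-up sizes for illustration; only the split choice matters here.
        self.samples = list(range(50)) if train else list(range(50, 60))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


class ToyLoader:
    """Only batches an existing dataset; it has no notion of a train/test split."""

    def __init__(self, dataset, batch_size=1):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        # Yield consecutive slices of the dataset, batch_size items at a time.
        for start in range(0, len(self.dataset), self.batch_size):
            end = min(start + self.batch_size, len(self.dataset))
            yield [self.dataset[i] for i in range(start, end)]


# Passing train to the loader fails the same way the question's code did,
# because ToyLoader.__init__ declares no such parameter.
try:
    ToyLoader(ToyDataset(), batch_size=10, train=True)
except TypeError as err:
    print(err)

# Passing it to the dataset works: one dataset object per split, one loader each.
train_loader = ToyLoader(ToyDataset(train=True), batch_size=10)
test_loader = ToyLoader(ToyDataset(train=False), batch_size=10)
print(len(train_loader.dataset), len(test_loader.dataset))  # 50 10
```

The exact wording of the printed error depends on the Python version (newer versions prefix the class name, e.g. ToyLoader.__init__()), but the "unexpected keyword argument 'train'" part is the same.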
