First, make sure your PyTorch version is 1.6 or later. Then run `train.py`, for example:
`$ python train.py --num_labels 4000 --save_name cifar10_4000 --dataset cifar10 --overwrite --data_dir path-to-your-data`
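The example command passes five flags. The sketch below mirrors them with Python's standard `argparse` module so that each flag's type is explicit. It is a hypothetical reconstruction, not the parser actually defined in `train.py`: the defaults and the stated meanings of `--save_name` and `--overwrite` are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical mirror of the flags in the example command.
    # Defaults and help texts are assumptions, not copied from train.py.
    parser = argparse.ArgumentParser(description="Multi-Head Co-Training (sketch)")
    parser.add_argument("--num_labels", type=int, default=4000,
                        help="number of labeled training examples")
    parser.add_argument("--save_name", type=str, default="cifar10_4000",
                        help="assumed: name of the run's output directory")
    parser.add_argument("--dataset", type=str, default="cifar10",
                        help="dataset name, e.g. cifar10")
    parser.add_argument("--overwrite", action="store_true",
                        help="assumed: overwrite an existing run with the same save_name")
    parser.add_argument("--data_dir", type=str, default="./data",
                        help="directory containing the dataset")
    return parser


if __name__ == "__main__":
    # Parse the same arguments as the README's example command.
    args = build_parser().parse_args([
        "--num_labels", "4000", "--save_name", "cifar10_4000",
        "--dataset", "cifar10", "--overwrite", "--data_dir", "path-to-your-data",
    ])
    print(vars(args))
```

Because `--overwrite` is declared with `action="store_true"`, it takes no value: passing it sets the flag to `True`, and omitting it leaves it `False`.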
- Python >= 3.6
- PyTorch >= 1.6
- CUDA
- NumPy
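To check the PyTorch >= 1.6 requirement, compare the numeric release components rather than the raw strings. A plain string comparison gets this wrong, because `"1.10" < "1.6"` is `True`. The helper `version_at_least` below is a hypothetical sketch, not part of this repository. It ignores local or pre-release suffixes such as `+cu113`.

```python
import re


def version_at_least(version: str, minimum: tuple) -> bool:
    """Return True if a version string such as '1.10.2+cu113' is >= minimum."""
    # Keep only the leading numeric release part, e.g. '1.10.2' from '1.10.2+cu113'.
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    parts = tuple(int(p) for p in match.group(1).split("."))

    # Pad both tuples with zeros so that (1, 6) and (1, 6, 0) compare as equal.
    width = max(len(parts), len(minimum))

    def pad(t: tuple) -> tuple:
        return tuple(t) + (0,) * (width - len(t))

    return pad(parts) >= pad(minimum)


if __name__ == "__main__":
    try:
        import torch
        print("PyTorch", torch.__version__,
              "OK" if version_at_least(torch.__version__, (1, 6)) else "too old",
              "| CUDA available:", torch.cuda.is_available())
    except ImportError:
        print("PyTorch is not installed")
```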
- Test error rate (%) on CIFAR-10

| # labels | 250 | 1000 | 4000 |
|---|---|---|---|
| Multi-Head Co-Training | 4.98±0.30 | 4.74±0.16 | 3.84±0.09 |

- Test error rate (%) on CIFAR-10 with only 60% known classes

| # labels | 50 | 100 | 400 |
|---|---|---|---|
| Multi-Head Co-Training | 5.8±0.9 | 5.3±0.9 | 4.4±0.9 |

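The table entries are error rates written as mean ± standard deviation. They cannot be accuracies, because an accuracy below 10% would be worse than chance on 10-class CIFAR-10. Some papers report accuracy instead. To compare against them, subtract the mean from 100; the standard deviation stays the same, since flipping the sign of a quantity does not change its spread. The helper `error_to_accuracy` below is a hypothetical illustration, not part of this repository.

```python
def error_to_accuracy(cell: str) -> tuple:
    """Convert a 'mean±std' error-rate cell (in %) to (accuracy mean, std)."""
    mean, std = (float(x) for x in cell.split("±"))
    # Accuracy = 100 - error; the standard deviation of (100 - X) equals that of X.
    return round(100.0 - mean, 2), std


if __name__ == "__main__":
    # CIFAR-10 row from the first table (250, 1000, 4000 labels).
    for cell in ["4.98±0.30", "4.74±0.16", "3.84±0.09"]:
        acc, std = error_to_accuracy(cell)
        print(f"{cell} error -> {acc:.2f}±{std:.2f} accuracy")
```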
Parts of the code in this repository are modified from: