这是一个unet-pytorch的源码,可以训练自己的模型

Overview

Unet:U-Net: Convolutional Networks for Biomedical Image Segmentation目标检测模型在Pytorch当中的实现


目录

  1. 性能情况 Performance
  2. 所需环境 Environment
  3. 注意事项 Attention
  4. 文件下载 Download
  5. 预测步骤 How2predict
  6. 训练步骤 How2train
  7. miou计算 miou
  8. 参考资料 Reference

性能情况

unet并不适合VOC此类数据集,其更适合特征少,需要浅层特征的医药数据集之类的。

训练数据集 权值文件名称 测试数据集 输入图片大小 mIOU
VOC12+SBD unet_voc.pth VOC-Val12 512x512 55.11

所需环境

torch==1.2.0
torchvision==0.4.0

注意事项

unet_voc.pth是基于VOC拓展数据集训练的。
unet_medical.pth是使用示例的细胞分割数据集训练的。
在使用时需要注意区分。

文件下载

训练所需的unet_voc.pth和unet_medical.pth可在百度网盘中下载。
链接: https://pan.baidu.com/s/1AUBpqsSgamoQGEYpNjJg7A 提取码: i3ck

VOC拓展数据集的百度网盘如下:
链接: https://pan.baidu.com/s/1BrR7AUM1XJvPWjKMIy2uEw 提取码: vszf

预测步骤

一、使用预训练权重

a、VOC预训练权重

  1. 下载完库后解压,如果想要利用voc训练好的权重进行预测,在百度网盘或者release下载unet_voc.pth,放入model_data,运行即可预测。
img/street.jpg
  1. 利用video.py可进行摄像头检测。

b、医药预训练权重

  1. 下载完库后解压,如果想要利用医药数据集训练好的权重进行预测,在百度网盘或者release下载unet_medical.pth,放入model_data,修改unet.py中的model_path和num_classes;
_defaults = {
    "model_path"        : 'model_data/unet_voc.pth',
    "model_image_size"  : (512, 512, 3),
    "num_classes"       : 21,
    "cuda"              : True,
    #--------------------------------#
    #   blend参数用于控制是否
    #   让识别结果和原图混合
    #--------------------------------#
    "blend"             : True
}
  1. 运行即可预测。
img/cell.png

二、使用自己训练的权重

  1. 按照训练步骤训练。
  2. 在unet.py文件里面,在如下部分修改model_path、backbone和num_classes使其对应训练好的文件;model_path对应logs文件夹下面的权值文件
_defaults = {
    "model_path"        : 'model_data/unet_voc.pth',
    "model_image_size"  : (512, 512, 3),
    "num_classes"       : 21,
    "cuda"              : True,
    #--------------------------------#
    #   blend参数用于控制是否
    #   让识别结果和原图混合
    #--------------------------------#
    "blend"             : True
}
  1. 运行predict.py,输入
img/street.jpg
  1. 利用video.py可进行摄像头检测。

训练步骤

一、训练voc数据集

  1. 将我提供的voc数据集放入VOCdevkit中(无需运行voc2unet.py)。
  2. 在train.py中设置对应参数,默认参数已经对应voc数据集所需要的参数了,所以只要修改backbone和model_path即可。
  3. 运行train.py进行训练。

二、训练自己的数据集

  1. 本文使用VOC格式进行训练。
  2. 训练前将标签文件放在VOCdevkit文件夹下的VOC2007文件夹下的SegmentationClass中。
  3. 训练前将图片文件放在VOCdevkit文件夹下的VOC2007文件夹下的JPEGImages中。
  4. 在训练前利用voc2unet.py文件生成对应的txt。
  5. 注意修改train.py的num_classes为分类个数+1。
  6. 运行train.py即可开始训练。

三、训练医药数据集

  1. 下载VGG的预训练权重到model_data下面。
  2. 按照默认参数运行train_medical.py即可开始训练。

miou计算

参考miou计算视频和博客。

Reference

https://github.com/ggyyzm/pytorch_segmentation
https://github.com/bonlime/keras-deeplab-v3-plus

You might also like...
Comments
  • 询问一下预训练的问题

    询问一下预训练的问题

    你好,打扰了。我是想问下主干模型是指的是在下采样过程中使用的vgg吗?如果我不改变上采样是不是就不用使用imagenet训练。然后注销掉model_path=‘’ 以及 if model_path !=‘’这段。然后使用自己的数据集去进行训练。 谢谢大佬!!!!!!。实际上大佬你的voc的权重文件是不是为二次预训练的数据。 不好意思,语言表达能力不行。俺不晓得这样说大佬明不明白。

    opened by Nine9844 5
  • 训练一段时间后,CE loss变为NAN

    训练一段时间后,CE loss变为NAN

    您好,看了您的教程我试着自己搭建了一个U-Net模型,并采用Dice + CE loss作为损失函数,但在迭代几十个epoch后,我的CE loss返回了NAN值,反馈的结果是 ‘Function 'LogSoftmaxBackward' returned nan values in its 0th output.’ 同样的数据在您源码上运行没有出现这个问题,请问您是否知道些解决方法?

    opened by Breeze-Zero 2
  • 为啥在dataloader第40行转换的array的shape和cv2不一样呢

    为啥在dataloader第40行转换的array的shape和cv2不一样呢

    我使用json_to_dataset.py转化mask后尝试使用代码查看shape import cv2 import numpy as np from PIL import Image

    file = '/home/fut/Downloads/unet-pytorch-main/mydata/masks/ID_1110_json.png' img = cv2.imread(file, cv2.IMREAD_UNCHANGED) print(img.shape)

    pil = Image.open(file) img2 = np.array(pil) print(img2.shape) 结果会是: (800, 800, 3) (800, 800) 为什么PIL读取后通道就没了,正是因为这个原因你的项目会很好跑起来。

    opened by futureflsl 1
  • from tqdm import tqdm 报错

    from tqdm import tqdm 报错

    import os import time

    import numpy as np import torch import torch.backends.cudnn as cudnn import torch.optim as optim from torch.utils.data import DataLoader from tqdm import tqdm

    opened by Luke-Wei 1
Releases(v3.0)
  • v3.0(Apr 22, 2022)

    重要更新

    • 支持step、cos学习率下降法。
    • 支持adam、sgd优化器选择。
    • 支持不同预测模式的选择,单张图片预测、文件夹预测、视频预测、图片裁剪。
    • 更新summary.py文件,用于观看网络结构。
    • 增加了多GPU训练。
    Source code(tar.gz)
    Source code(zip)
  • v2.2(Mar 4, 2022)

    重要更新

    • 更新train.py文件,增加了大量的注释,增加多个可调整参数。
    • 更新predict.py文件,增加了大量的注释,增加fps、视频预测、批量预测等功能。
    • 更新unet.py文件,增加了大量的注释,增加先验框选择、置信度、非极大抑制等参数。
    • 合并get_dr_txt.py、get_gt_txt.py和get_map.py文件,通过一个文件来实现数据集的评估。
    • 更新voc_annotation.py文件,增加多个可调整参数。
    • 更新callback.py文件,防止多线程错误。
    • 更新summary.py文件,用于观看网络结构。
    Source code(tar.gz)
    Source code(zip)
Owner
Bubbliiiing
Bubbliiiing
A curated (most recent) list of resources for Learning with Noisy Labels

A curated (most recent) list of resources for Learning with Noisy Labels

Jiaheng Wei 321 Jan 09, 2023
Continual World is a benchmark for continual reinforcement learning

Continual World Continual World is a benchmark for continual reinforcement learning. It contains realistic robotic tasks which come from MetaWorld. Th

41 Dec 24, 2022
Distributed Evolutionary Algorithms in Python

DEAP DEAP is a novel evolutionary computation framework for rapid prototyping and testing of ideas. It seeks to make algorithms explicit and data stru

Distributed Evolutionary Algorithms in Python 4.9k Jan 05, 2023
PyTorch-based framework for Deep Hedging

PFHedge: Deep Hedging in PyTorch PFHedge is a PyTorch-based framework for Deep Hedging. PFHedge Documentation Neural Network Architecture for Efficien

139 Dec 30, 2022
[NeurIPS-2020] Self-paced Contrastive Learning with Hybrid Memory for Domain Adaptive Object Re-ID.

Self-paced Contrastive Learning (SpCL) The official repository for Self-paced Contrastive Learning with Hybrid Memory for Domain Adaptive Object Re-ID

Yixiao Ge 286 Dec 21, 2022
(CVPR2021) Kaleido-BERT: Vision-Language Pre-training on Fashion Domain

Kaleido-BERT: Vision-Language Pre-training on Fashion Domain Mingchen Zhuge*, Dehong Gao*, Deng-Ping Fan#, Linbo Jin, Ben Chen, Haoming Zhou, Minghui

248 Dec 04, 2022
LSTMs (Long Short Term Memory) RNN for prediction of price trends

Price Prediction with Recurrent Neural Networks LSTMs BTC-USD price prediction with deep learning algorithm. Artificial Neural Networks specifically L

5 Nov 12, 2021
Source code and Dataset creation for the paper "Neural Symbolic Regression That Scales"

NeuralSymbolicRegressionThatScales Pytorch implementation and pretrained models for the paper "Neural Symbolic Regression That Scales", presented at I

35 Nov 25, 2022
CL-Gym: Full-Featured PyTorch Library for Continual Learning

CL-Gym: Full-Featured PyTorch Library for Continual Learning CL-Gym is a small yet very flexible library for continual learning research and developme

Iman Mirzadeh 36 Dec 25, 2022
DropNAS: Grouped Operation Dropout for Differentiable Architecture Search

DropNAS: Grouped Operation Dropout for Differentiable Architecture Search DropNAS, a grouped operation dropout method for one-level DARTS, with better

weijunhong 4 Aug 15, 2022
Pytorch implementation of the paper "Class-Balanced Loss Based on Effective Number of Samples"

Class-balanced-loss-pytorch Pytorch implementation of the paper Class-Balanced Loss Based on Effective Number of Samples presented at CVPR'19. Yin Cui

Vandit Jain 697 Dec 29, 2022
Open source person re-identification library in python

Open-ReID Open-ReID is a lightweight library of person re-identification for research purpose. It aims to provide a uniform interface for different da

Tong Xiao 1.3k Jan 01, 2023
DenseCLIP: Language-Guided Dense Prediction with Context-Aware Prompting

DenseCLIP: Language-Guided Dense Prediction with Context-Aware Prompting Created by Yongming Rao*, Wenliang Zhao*, Guangyi Chen, Yansong Tang, Zheng Z

Yongming Rao 322 Dec 31, 2022
Systemic Evolutionary Chemical Space Exploration for Drug Discovery

SECSE SECSE: Systemic Evolutionary Chemical Space Explorer Chemical space exploration is a major task of the hit-finding process during the pursuit of

64 Dec 16, 2022
Neural models of common sense. 🤖

Unicorn on Rainbow Neural models of common sense. This repository is for the paper: Unicorn on Rainbow: A Universal Commonsense Reasoning Model on a N

AI2 60 Jan 05, 2023
Planner_backend - Academic planner application designed for students and counselors.

Planner (backend) Academic planner application designed for students and advisors.

2 Dec 31, 2021
CMT: Convolutional Neural Networks Meet Vision Transformers

CMT: Convolutional Neural Networks Meet Vision Transformers [arxiv] 1. Introduction This repo is the CMT model which impelement with pytorch, no refer

FlyEgle 83 Dec 30, 2022
ReLoss - Official implementation for paper "Relational Surrogate Loss Learning" ICLR 2022

Relational Surrogate Loss Learning (ReLoss) Official implementation for paper "R

Tao Huang 31 Nov 22, 2022
Unofficial Implementation of MLP-Mixer in TensorFlow

mlp-mixer-tf Unofficial Implementation of MLP-Mixer [abs, pdf] in TensorFlow. Note: This project may have some bugs in it. I'm still learning how to i

Rishabh Anand 24 Mar 23, 2022
This is an official source code for implementation on Extensive Deep Temporal Point Process

Extensive Deep Temporal Point Process This is an official source code for implementation on Extensive Deep Temporal Point Process, which is composed o

Haitao Lin 8 Aug 15, 2022