This is the source code of unet-pytorch; you can use it to train your own models.

Overview

Unet: a PyTorch implementation of the semantic segmentation model from U-Net: Convolutional Networks for Biomedical Image Segmentation.


Contents

  1. Performance
  2. Environment
  3. Attention
  4. Download
  5. How2predict
  6. How2train
  7. miou calculation
  8. Reference

Performance

U-Net is not well suited to datasets like VOC. It works better on datasets with fewer features that rely on shallow features, such as medical imaging datasets.

Training dataset | Weight file | Test dataset | Input image size | mIOU
VOC12+SBD | unet_voc.pth | VOC-Val12 | 512x512 | 55.11

Environment

torch==1.2.0
torchvision==0.4.0
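
A quick way to check that the installed versions match the ones above (a generic check, nothing repository-specific):

import torch
import torchvision

print(torch.__version__)        # expected: 1.2.0
print(torchvision.__version__)  # expected: 0.4.0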

Attention

unet_voc.pth was trained on the extended VOC dataset.
unet_medical.pth was trained on the example cell segmentation dataset.
Make sure to use the correct weights for your task.

Download

The unet_voc.pth and unet_medical.pth weights required for training can be downloaded from Baidu Netdisk:
Link: https://pan.baidu.com/s/1AUBpqsSgamoQGEYpNjJg7A  Extraction code: i3ck

The extended VOC dataset is also available on Baidu Netdisk:
Link: https://pan.baidu.com/s/1BrR7AUM1XJvPWjKMIy2uEw  Extraction code: vszf

How2predict

I. Using pre-trained weights

a. VOC pre-trained weights

  1. After downloading and unzipping the repository, download unet_voc.pth from Baidu Netdisk or the Releases page if you want to predict with the VOC-trained weights, place it in model_data, and run predict.py (a minimal sketch of the prediction loop follows these steps). When prompted, enter an image path such as
img/street.jpg
  2. video.py can be used for camera detection.
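
For orientation, here is a minimal sketch of the kind of loop predict.py runs. It assumes the repository's unet.py exposes a Unet wrapper class with a detect_image method; treat these names as assumptions and check predict.py itself for the actual interface.

from PIL import Image

from unet import Unet  # assumed wrapper class provided by this repository

unet = Unet()

while True:
    img = input('Input image filename:')
    try:
        image = Image.open(img)
    except Exception:
        print('Open Error! Try again!')
        continue
    # detect_image is assumed to return a PIL image with the segmentation drawn on it
    r_image = unet.detect_image(image)
    r_image.show()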

b. Medical pre-trained weights

  1. After downloading and unzipping the repository, download unet_medical.pth from Baidu Netdisk or the Releases page if you want to predict with the weights trained on the medical dataset, place it in model_data, and modify model_path and num_classes in unet.py (an illustration of the adjusted values follows these steps):
_defaults = {
    "model_path"        : 'model_data/unet_voc.pth',
    "model_image_size"  : (512, 512, 3),
    "num_classes"       : 21,
    "cuda"              : True,
    #--------------------------------#
    #   The blend parameter controls whether
    #   the prediction is blended with
    #   the original image
    #--------------------------------#
    "blend"             : True
}
  2. Run predict.py to predict and enter an image path such as
img/cell.png
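
As an illustration only, the adjusted block might look like the sketch below. The value of num_classes is an assumption (background plus one foreground class for the example cell dataset); verify it against your dataset and the settings used in train_medical.py.

_defaults = {
    "model_path"        : 'model_data/unet_medical.pth',
    "model_image_size"  : (512, 512, 3),
    "num_classes"       : 2,    # assumed: background + cell
    "cuda"              : True,
    "blend"             : True
}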

II. Using your own trained weights

  1. Train a model following the training steps.
  2. In unet.py, modify model_path and num_classes in the following block so that they match your trained model (and backbone as well, if your version of unet.py exposes it); model_path points to the weight file in the logs folder:
_defaults = {
    "model_path"        : 'model_data/unet_voc.pth',
    "model_image_size"  : (512, 512, 3),
    "num_classes"       : 21,
    "cuda"              : True,
    #--------------------------------#
    #   The blend parameter controls whether
    #   the prediction is blended with
    #   the original image
    #--------------------------------#
    "blend"             : True
}
  3. Run predict.py and enter an image path such as
img/street.jpg
  4. video.py can be used for camera detection (a hedged sketch follows this list).
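
A rough sketch of what a video.py-style camera loop does is given below. It assumes the same Unet wrapper class and detect_image method as the prediction sketch above and uses OpenCV to read frames; the names are assumptions, so check video.py for the actual code.

import cv2
import numpy as np
from PIL import Image

from unet import Unet  # assumed wrapper class provided by this repository

unet = Unet()
capture = cv2.VideoCapture(0)  # 0 = default camera

while True:
    ok, frame = capture.read()
    if not ok:
        break
    # OpenCV delivers BGR frames; convert to an RGB PIL image for the model
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    result = unet.detect_image(Image.fromarray(rgb))  # assumed method name
    result = cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR)
    cv2.imshow("video", result)
    if cv2.waitKey(1) & 0xFF == 27:  # press Esc to quit
        break

capture.release()
cv2.destroyAllWindows()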

How2train

I. Training on the VOC dataset

  1. Place the provided VOC dataset in VOCdevkit (there is no need to run voc2unet.py).
  2. Set the corresponding parameters in train.py; the defaults already match the VOC dataset, so normally only backbone and model_path need to be changed (an illustration of the kind of settings involved follows these steps).
  3. Run train.py to start training.
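
For reference only, the settings referred to above are plain Python variables near the top of train.py and typically look something like the sketch below. The names and values here are assumptions meant to show the kind of block involved, not the exact contents of train.py; rely on the comments inside train.py itself.

# Hypothetical excerpt; names and values are illustrative only.
num_classes = 21                          # number of classes + 1 (21 for VOC)
backbone    = "vgg"                       # encoder used in the downsampling path
model_path  = "model_data/unet_voc.pth"   # pre-trained weights to start from
inputs_size = [512, 512, 3]               # input size fed to the network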

II. Training on your own dataset

  1. This repository uses the VOC format for training.
  2. Before training, place the label (mask) files in the SegmentationClass folder under VOCdevkit/VOC2007.
  3. Before training, place the image files in the JPEGImages folder under VOCdevkit/VOC2007.
  4. Before training, run voc2unet.py to generate the corresponding txt files (a sketch of what such a split script does follows these steps).
  5. Remember to set num_classes in train.py to the number of classes + 1.
  6. Run train.py to start training.
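
As a rough sketch of what a voc2unet.py-style script produces: it lists the sample IDs found in SegmentationClass and splits them into train/val txt files. The output folder, file names, and split ratio below are assumptions; the actual script may differ.

import os
import random

# Paths follow the VOC layout described in the steps above
seg_dir = "VOCdevkit/VOC2007/SegmentationClass"
out_dir = "VOCdevkit/VOC2007/ImageSets/Segmentation"
os.makedirs(out_dir, exist_ok=True)

# Every mask PNG defines one sample ID
ids = [name[:-4] for name in os.listdir(seg_dir) if name.endswith(".png")]
random.shuffle(ids)
split = int(0.9 * len(ids))  # illustrative 90/10 train/val split

with open(os.path.join(out_dir, "train.txt"), "w") as f:
    f.write("\n".join(ids[:split]))
with open(os.path.join(out_dir, "val.txt"), "w") as f:
    f.write("\n".join(ids[split:]))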

III. Training on the medical dataset

  1. Download the VGG pre-trained weights into model_data.
  2. Run train_medical.py with the default parameters to start training.

miou calculation

Refer to the miou calculation video and blog post.
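
For reference, per-class IoU and mIoU can be computed from a confusion matrix as in the generic NumPy sketch below. This is not the repository's own miou script, just the standard formula.

import numpy as np

def compute_miou(confusion):
    """confusion[i, j] = pixels with ground-truth class i predicted as class j."""
    intersection = np.diag(confusion)
    union = confusion.sum(axis=1) + confusion.sum(axis=0) - intersection
    iou = intersection / np.maximum(union, 1)  # guard against empty classes
    return iou, iou.mean()

# Toy example with 2 classes
confusion = np.array([[50, 10],
                      [ 5, 35]])
iou, miou = compute_miou(confusion)
print(iou, miou)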

Reference

https://github.com/ggyyzm/pytorch_segmentation
https://github.com/bonlime/keras-deeplab-v3-plus

Comments
  • A question about pre-training

    Hello, sorry to bother you. Does the backbone model refer to the VGG used in the downsampling path? If I keep the upsampling part unchanged, does that mean I do not need ImageNet pre-training, so I can comment out model_path = '' and the if model_path != '' block and then train on my own dataset? Thank you! Also, is your VOC weight file actually the result of a second round of pre-training? Sorry, my wording is not great; I am not sure whether this is clear.

    opened by Nine9844 5
  • CE loss becomes NaN after some training

    Hello, after reading your tutorial I tried to build a U-Net model myself, using Dice + CE loss as the loss function, but after a few dozen epochs the CE loss returned NaN, with the message "Function 'LogSoftmaxBackward' returned nan values in its 0th output." The same data runs on your source code without this problem. Do you know of any way to fix it?

    opened by Breeze-Zero 2
  • Why does the shape of the array converted at line 40 of the dataloader differ from cv2?

    After converting the masks with json_to_dataset.py, I checked their shapes with the following code:

    import cv2
    import numpy as np
    from PIL import Image

    file = '/home/fut/Downloads/unet-pytorch-main/mydata/masks/ID_1110_json.png'
    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    print(img.shape)

    pil = Image.open(file)
    img2 = np.array(pil)
    print(img2.shape)

    The output is (800, 800, 3) and (800, 800). Why does the channel dimension disappear when the file is read with PIL? It is precisely for this reason that your project runs well.

    opened by futureflsl 1
  • from tqdm import tqdm raises an error

    import os
    import time

    import numpy as np
    import torch
    import torch.backends.cudnn as cudnn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    opened by Luke-Wei 1
Releases (v3.0)
  • v3.0 (Apr 22, 2022)

    Major updates

    • Support for step and cos learning rate decay.
    • Support for choosing the adam or sgd optimizer.
    • Support for different prediction modes: single-image prediction, folder prediction, video prediction, and image cropping.
    • Updated summary.py for inspecting the network structure.
    • Added multi-GPU training.
  • v2.2 (Mar 4, 2022)

    Major updates

    • Updated train.py with extensive comments and more adjustable parameters.
    • Updated predict.py with extensive comments and added FPS measurement, video prediction, batch prediction, and other features.
    • Updated unet.py with extensive comments and added parameters such as prior box selection, confidence, and non-maximum suppression.
    • Merged get_dr_txt.py, get_gt_txt.py, and get_map.py so that dataset evaluation is handled by a single file.
    • Updated voc_annotation.py with more adjustable parameters.
    • Updated callback.py to prevent multi-threading errors.
    • Updated summary.py for inspecting the network structure.