Implementing logistic regression with PyTorch
2022-07-26 08:54:00 【Miracle Fan】
1. Import the required APIs
import torch
import torch.nn as nn
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
2. Prepare the data
bc = datasets.load_breast_cancer()
X, y = bc.data, bc.target
n_samples, n_features = X.shape  # number of samples and number of features per sample
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1234)  # hold out 20% as the test set
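To make the split concrete, here is a minimal NumPy sketch of what an 80/20 shuffled hold-out split does. `simple_train_test_split` is a hypothetical helper, not sklearn's implementation: it rounds the test count up and shuffles with NumPy's `default_rng`, so even with the same seed it will not pick the same rows as `train_test_split(..., random_state=1234)`. The toy arrays stand in for the breast-cancer data.

```python
import math
import numpy as np

def simple_train_test_split(X, y, test_size=0.2, seed=1234):
    """Hypothetical helper: shuffle row indices once, then cut off the test share."""
    n = X.shape[0]
    n_test = math.ceil(test_size * n)      # this sketch rounds the test count up
    rng = np.random.default_rng(seed)      # seeded so the split is reproducible
    idx = rng.permutation(n)               # one shared permutation keeps X and y aligned
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]

X = np.arange(200, dtype=float).reshape(100, 2)   # row i is [2i, 2i + 1]
y = np.arange(100) % 2                             # label of row i is i % 2
X_tr, X_te, y_tr, y_te = simple_train_test_split(X, y)
print(X_tr.shape, X_te.shape)  # (80, 2) (20, 2)
```

Because features and labels are indexed with the same permutation, every row keeps its own label after shuffling.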
3. Data preprocessing
3.1 Feature scaling
# Call fit_transform(train_data) first, then transform(test_data).
# Calling transform(test_data) on a scaler that has not been fitted raises an error (NotFittedError).
# Calling fit_transform(test_data) instead of transform(test_data) after fit_transform(train_data) still standardizes the test data,
# but with the test set's own mean and variance, so the two sets are no longer on the same scale. Always avoid this.
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
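The reason for fitting the scaler on the training set only can be shown with a small NumPy sketch. `TinyStandardScaler` is a hypothetical stand-in that applies z = (x - mean) / std with the population standard deviation; unlike sklearn's StandardScaler it does not guard against zero-variance features. The test data is deliberately drawn with a shifted mean, an assumption made purely for illustration.

```python
import numpy as np

class TinyStandardScaler:
    """Hypothetical minimal scaler: z = (x - mean) / std, with mean/std learned in fit()."""
    def fit(self, X):
        self.mean_ = X.mean(axis=0)
        self.scale_ = X.std(axis=0)        # population std (ddof=0)
        return self

    def transform(self, X):
        if not hasattr(self, "mean_"):     # mimic the "not fitted" error
            raise RuntimeError("call fit() before transform()")
        return (X - self.mean_) / self.scale_

    def fit_transform(self, X):
        return self.fit(X).transform(X)

rng = np.random.default_rng(0)
X_train = rng.normal(loc=5.0, scale=2.0, size=(400, 3))
X_test = rng.normal(loc=8.0, scale=2.0, size=(100, 3))     # shifted on purpose

sc = TinyStandardScaler()
Z_train = sc.fit_transform(X_train)
Z_test = sc.transform(X_test)                              # correct: reuses training mean/std
Z_test_wrong = TinyStandardScaler().fit_transform(X_test)  # wrong: test set gets its own statistics
print(Z_test.mean(axis=0).round(2), Z_test_wrong.mean(axis=0).round(2))
```

With the training statistics the shifted test set keeps a mean of roughly 1.5 standard units; refitting on the test set erases that offset and centers it at 0, so the two sets are no longer comparable.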
3.2 Data type conversion
X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
y_train = torch.from_numpy(y_train.astype(np.float32))
y_test = torch.from_numpy(y_test.astype(np.float32))
3.3 Data format adjustment
tensor.view() is similar to NumPy's array.reshape(); here it turns each label vector into a single column.
# reshape from (n,) to (n, 1), similar to NumPy's reshape()
y_train = y_train.view(y_train.shape[0], 1)
y_test = y_test.view(y_test.shape[0], 1)
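Why the labels need the extra dimension: the model returns predictions of shape (n, 1), and PyTorch's BCELoss documents the target as having the same shape as the input. The NumPy sketch below, with made-up toy values rather than the breast-cancer labels, shows what a mismatched (n,) label vector does: broadcasting silently builds an (n, n) matrix instead of pairing each prediction with its own label.

```python
import numpy as np

p = np.array([[0.9], [0.2], [0.6]])   # model output, shape (n, 1), like y_pred
y = np.array([1.0, 0.0, 1.0])         # labels straight from the dataset, shape (n,)

diff_bad = p - y                      # broadcasting: (3, 1) with (3,) gives (3, 3)
y_col = y.reshape(y.shape[0], 1)      # NumPy counterpart of tensor.view(n, 1)
diff_ok = p - y_col                   # element-wise, shape (3, 1)
print(diff_bad.shape, diff_ok.shape)  # (3, 3) (3, 1)
```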
4. Build a model
4.1 Define the basic network model
class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()  # initialize the nn.Module base class
        self.linear = nn.Linear(n_input_features, 1)  # linear layer: n_input_features inputs, a single output value
    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))  # sigmoid maps the output to a value between 0 and 1
        return y_pred
model = Model(n_features)
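The model computes sigmoid(x·w + b) for every row. A NumPy sketch of the same forward pass follows, with made-up random weights standing in for `linear.weight` and `linear.bias`. nn.Linear stores its weight with shape (1, n_features) and computes x @ weight.T + bias, so `w` below plays the role of that weight transposed; the 30 features are only an illustrative size.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def logistic_forward(X, w, b):
    """Same computation as Model.forward: sigmoid(X @ w + b), one probability per row."""
    return sigmoid(X @ w + b)

rng = np.random.default_rng(0)
n_samples, n_features = 5, 30
X = rng.normal(size=(n_samples, n_features))
w = rng.normal(scale=0.1, size=(n_features, 1))  # arbitrary init, stands in for linear.weight.T
b = np.zeros(1)                                   # stands in for linear.bias
p = logistic_forward(X, w, b)
print(p.shape)  # (5, 1)
```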
4.2 Define loss and optimizer
num_epochs = 100
lr = 0.01
criterion = nn.BCELoss()  # this is binary classification, so use binary cross-entropy
optimizer = torch.optim.SGD(model.parameters(), lr=lr)  # lr is the learning rate, a hyperparameter
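For reference, BCELoss with its default mean reduction computes L = -(1/n) Σ [y·log p + (1 - y)·log(1 - p)]. A NumPy sketch of that formula follows; the clipping with `eps` is this sketch's own guard against log(0), whereas PyTorch's BCELoss instead clamps the log terms at -100. A model that outputs 0.5 everywhere scores ln 2 ≈ 0.6931, a handy baseline when reading the training log.

```python
import numpy as np

def bce(p, y, eps=1e-12):
    """Mean binary cross-entropy: -mean(y*log(p) + (1-y)*log(1-p))."""
    p = np.clip(p, eps, 1 - eps)   # this sketch's guard against log(0)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))

y = np.array([[1.0], [0.0], [1.0], [0.0]])
print(bce(np.full_like(y, 0.5), y))                        # ≈ 0.6931 (ln 2): an uninformed model
print(bce(np.array([[0.9], [0.1], [0.8], [0.3]]), y))      # ≈ 0.1976: mostly confident and correct
```

`torch.optim.SGD` without momentum then moves every parameter θ by θ ← θ - lr · ∂L/∂θ after each backward pass.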
5. Model training
for epoch in range(num_epochs):
    # forward pass: compute predictions and the loss
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)
    # backward pass: compute gradients, then update the parameters
    loss.backward()
    optimizer.step()
    # clear the gradients before the next iteration; PyTorch accumulates them otherwise
    optimizer.zero_grad()
    if (epoch+1) % 10 == 0:
        print(f'epoch: {epoch+1}, loss = {loss.item():.4f}')
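The three calls in the loop can be mirrored by hand in NumPy. For a sigmoid output trained with mean BCE, the gradient with respect to the logit z is (p - y)/n; backpropagating that into the weights and bias is the job of `loss.backward()`, and `optimizer.step()` then applies the SGD update. This is a sketch under stated assumptions: the data are synthetic and linearly separable, and the learning rate is 0.1 rather than the article's 0.01 so that the loss drops visibly within 100 epochs.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce(p, y, eps=1e-12):
    p = np.clip(p, eps, 1 - eps)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))

rng = np.random.default_rng(1234)
n, d = 200, 4
X = rng.normal(size=(n, d))
true_w = np.array([[2.0], [-1.0], [0.5], [0.0]])
y = (X @ true_w > 0).astype(float)      # synthetic separable labels, shape (n, 1)

w = np.zeros((d, 1))
b = 0.0
lr, num_epochs = 0.1, 100
losses = []
for epoch in range(num_epochs):
    p = sigmoid(X @ w + b)              # forward pass
    losses.append(bce(p, y))
    grad_z = (p - y) / n                # d(mean BCE)/d(logit) for a sigmoid output
    grad_w = X.T @ grad_z               # what loss.backward() would store in weight.grad
    grad_b = grad_z.sum()               # ... and in bias.grad
    w -= lr * grad_w                    # what optimizer.step() does for plain SGD
    b -= lr * grad_b
    # gradients are recomputed from scratch each epoch here, so nothing needs zeroing
    if (epoch + 1) % 20 == 0:
        print(f"epoch: {epoch + 1}, loss = {losses[-1]:.4f}")
```

In PyTorch, by contrast, the `.grad` buffers are summed across `backward()` calls, which is why `optimizer.zero_grad()` is required; calling it at the start of the loop or right after `step()` both work, as long as it runs before the next `backward()`.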
6. Model test
with torch.no_grad():
    y_predicted = model(X_test)
    y_predicted_cls = y_predicted.round()  # round the 0~1 outputs to get the class labels 0 or 1
    acc = y_predicted_cls.eq(y_test).sum() / float(y_test.shape[0])
    print(f'accuracy: {acc.item():.4f}')
y_predicted_cls.eq(y_test) is 1 (True) wherever the predicted class equals the test label. Summing it counts the correct classifications, and dividing by the number of test samples gives the accuracy on the test set.
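A NumPy sketch of the same accuracy computation, using made-up toy probabilities. Note that rounding sends exactly 0.5 to 0: both `np.round` and `torch.round` round half to even. Thresholding with `y_prob >= 0.5` is an alternative if ties should go to class 1.

```python
import numpy as np

y_prob = np.array([[0.93], [0.12], [0.50], [0.71], [0.48]])  # toy model outputs
y_true = np.array([[1.0], [0.0], [1.0], [1.0], [1.0]])

y_cls = np.round(y_prob)                      # 0.5 rounds to 0.0 (round half to even)
correct = (y_cls == y_true)                   # boolean mask, like y_predicted_cls.eq(y_test)
acc = correct.sum() / float(y_true.shape[0])  # correct predictions / number of samples
print(acc)  # 0.6
```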