Using torch.nn.CTCLoss()
2022-07-18 03:37:00 [Chen Zhuangshi's Programming Life]
1. Introduction
CTC (Connectionist Temporal Classification) mainly solves the problem of misalignment between labels and network outputs.
Advantage: it does not force the labels to be aligned, i.e. labels may vary in length; training only needs the input sequences and their supervising label sequences.
Application scenarios: scene text recognition, speech recognition, handwriting recognition, and similar tasks.
2. Usage
Step 1: create a CTCLoss object
ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')
Parameter description:
(1) blank: the label value at which the blank label sits. The default is 0, and it must be set according to how the labels are actually defined. When predicting text there is usually a blank character, and blank is simply its position in the full character set.
(2) reduction: how the output losses are handled. A string, one of 'none', 'mean' and 'sum'. 'none' leaves the per-sample losses untouched, 'mean' averages them over the batch (each sample's loss is first divided by its target length), and 'sum' adds them up. The default is 'mean'.
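As a minimal sketch of the reduction options (shapes and label values here are made up for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, N, C = 12, 2, 5  # output length, batch size, character-set size (incl. blank)

log_probs = torch.randn(T, N, C).log_softmax(2)
targets = torch.randint(1, C, (N, 6), dtype=torch.long)  # blank index 0 excluded
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 6, dtype=torch.long)

# 'none' keeps one loss per sample; 'sum' adds them up; 'mean' first divides
# each sample's loss by its target length, then averages over the batch
loss_none = nn.CTCLoss(blank=0, reduction='none')(log_probs, targets, input_lengths, target_lengths)
loss_sum = nn.CTCLoss(blank=0, reduction='sum')(log_probs, targets, input_lengths, target_lengths)
print(loss_none.shape)  # torch.Size([2]); loss_none.sum() equals loss_sum
```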
Step 2: call CTCLoss inside the training loop to compute the loss
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
Parameter description:
(1) log_probs: a tensor of shape (T, N, C) holding the model output. T is the output sequence length, N is the batch size, and C is the size of the full character set to be predicted, including the blank label.
For example, shape = (50, 32, 5000) means one image yields at most 50 characters, the batch size is 32, and the whole dataset's character set has 5000 entries.
Note: log_probs usually needs to pass through torch.nn.functional.log_softmax before being fed into CTCLoss.
(2) targets: a tensor of shape (N, S) or (sum(target_lengths),). In the first form, N is the batch size and S is the label length; e.g. shape = (32, 50) means a batch of 32 labels with up to 50 characters each.
The second form is simply all labels concatenated into one 1-D tensor. Either way, note that targets must not contain the blank label.
(3) input_lengths: a tensor or tuple of shape (N,). Each element is the length of the corresponding output sequence and may not exceed T; when the model's output length is fixed, every element is simply T.
(4) target_lengths: a tensor or tuple of shape (N,). Each element is the label length of the corresponding training sample; label lengths may vary.
For example, target_lengths = [23, 34, 32, …, 45, 34] means the first image's label has 23 characters, the second image's label has 34 characters, and so on.
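To make the two targets layouts concrete, here is a small sketch with three made-up variable-length labels, showing the concatenated 1-D form and the padded (N, S) form:

```python
import torch

# Three samples with variable-length labels (values are arbitrary examples)
labels = [[5, 2, 7], [1, 9], [3, 3, 8, 2]]
target_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)

# Form 1: all labels concatenated, shape (sum(target_lengths),)
targets_flat = torch.cat([torch.tensor(l, dtype=torch.long) for l in labels])

# Form 2: padded 2-D tensor of shape (N, S) where S = max label length;
# pad values are ignored because target_lengths marks each row's real length
S = int(target_lengths.max())
targets_2d = torch.zeros(len(labels), S, dtype=torch.long)
for i, l in enumerate(labels):
    targets_2d[i, :len(l)] = torch.tensor(l, dtype=torch.long)
```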
Step 3: example, CTCLoss in license plate recognition
(1) Character set: CHARS
CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
         '新',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', 'I', 'O', '-'
        ]
(2) Create the CTCLoss object
Since the blank label sits at index len(CHARS)-1 and we want CTCLoss to reduce the output losses with 'mean', the CTCLoss class is initialized as follows:
ctc_loss = nn.CTCLoss(blank=len(CHARS)-1, reduction='mean')
We set the output sequence length T to 18 and the training batch size N to 4; for the sake of illustration the training set contains only 4 license plates, and the total character-set size C is 68, as in CHARS above.
(3) Explanation of the CTCLoss inputs
Printing each input parameter during one training iteration gives the following results:
1) log_probs holds many values and is just the network's forward output, so we only print its shape:
torch.Size([18, 4, 68])
2) Printing targets gives the training labels of the four plates concatenated together; splitting them by target_lengths recovers each plate's label:
tensor([18, 45, 33, 37, 40, 49, 63,  4, 54, 51, 34, 53, 37, 38, 22, 56, 37, 38, 33, 39, 34, 46,  2, 41, 44, 37, 39, 35, 33, 40])
30 numbers in total, because the actual lengths of the four license plate labels are (7, 7, 8, 8), i.e. 30 characters altogether.
3) Printing target_lengths shows one element per plate, telling how many consecutive entries of targets make up that plate's label:
(7, 7, 8, 8)
4) Printing input_lengths: since the output sequence length T is fixed at 18, its elements are all identical:
(18, 18, 18, 18)
In short, as long as the model configuration is fixed, log_probs needs no extra assembly before being passed to CTCLoss, but the other three inputs must be built to match the actual dataset and the chosen C, T and N.
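Putting Step 3 together, a minimal end-to-end sketch (the log_probs here are random stand-ins for a real network's forward output, and the label indices are made up):

```python
import torch
import torch.nn as nn

T, N, C = 18, 4, 68  # output length, batch size, character-set size (incl. blank)
ctc_loss = nn.CTCLoss(blank=C - 1, reduction='mean')

# Stand-in for the network's forward output, already log_softmax-ed
log_probs = torch.randn(T, N, C).log_softmax(2)

# Four plates of lengths 7, 7, 8, 8 -> 30 label indices concatenated
target_lengths = torch.tensor([7, 7, 8, 8], dtype=torch.long)
targets = torch.randint(0, C - 1, (int(target_lengths.sum()),), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
```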
3. Things to note
3.1 The official example is shown below, but in real training the .detach() on log_probs must be removed, otherwise no gradient flows back into the network and training is impossible.
For example:
>>> ctc_loss = nn.CTCLoss()
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
>>> loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()
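For comparison, a minimal training sketch without the .detach() (the linear layer and optimizer are hypothetical stand-ins for a real recognition network):

```python
import torch
import torch.nn as nn

T, N, C = 50, 16, 20
model = nn.Linear(32, C)                  # stand-in for a real recognition network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ctc_loss = nn.CTCLoss()                   # blank defaults to 0

x = torch.randn(T, N, 32)
log_probs = model(x).log_softmax(2)       # no .detach(): keep the graph intact

targets = torch.randint(1, C, (N, 30), dtype=torch.long)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(10, 25, (N,), dtype=torch.long)

loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
optimizer.zero_grad()
loss.backward()                           # gradients now reach model.weight
optimizer.step()
```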
3.2 blank must be set according to the position of the blank character in the full prediction character set, otherwise training will go wrong;
3.3 It is recommended to give targets the shape (sum(target_lengths),) and let target_lengths delimit each label. If the (N, S) form is used with variable label lengths, S must equal the longest label and every shorter row must be padded (each row of a 2-D tensor must have the same length); sizing the second dimension to anything less than the maximum would truncate the longer labels and make it impossible for the model to learn them;
3.4 When designing the model, choose the output sequence length T with the longest sequence to predict in mind: if the longest label has length I, then in theory T should be at least 2I+1, because in the worst case CTC needs a blank label before and after each real label to keep repeated characters distinguishable;
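The length requirement in 3.4 can be checked directly: for the repeated label "aa" (length 2), no valid CTC alignment exists when T = 2, and PyTorch reports an infinite loss. A small sketch with a made-up 3-character set where index 0 is the blank:

```python
import torch
import torch.nn as nn

ctc = nn.CTCLoss(blank=0, reduction='none')

# A label of two identical symbols: collapsing repeats means a blank
# must sit between them, so at least 3 time steps are required
targets = torch.tensor([1, 1], dtype=torch.long)
target_lengths = torch.tensor([2], dtype=torch.long)

loss_T2 = ctc(torch.randn(2, 1, 3).log_softmax(2), targets,
              torch.tensor([2]), target_lengths)   # no valid path -> inf
loss_T3 = ctc(torch.randn(3, 1, 3).log_softmax(2), targets,
              torch.tensor([3]), target_lengths)   # feasible -> finite
```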
3.5 Besides applying log_softmax() to the model output before feeding it into CTCLoss, you must also reorder the dimensions so that its shape is (T, N, C)!
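For example, if a (hypothetical) model head emits (N, T, C), the reordering described in 3.5 looks like:

```python
import torch
import torch.nn.functional as F

out = torch.randn(4, 18, 68)                             # (N, T, C) from the head
log_probs = F.log_softmax(out.permute(1, 0, 2), dim=2)   # -> (T, N, C)
print(log_probs.shape)  # torch.Size([18, 4, 68])
```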
Note:
Written with reference to https://zhuanlan.zhihu.com/p/67415439, plus some understanding of my own.