46 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
			
		
		
	
	
			46 lines
		
	
	
		
			1.6 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
from torch import nn
 | 
						|
from .RNN import SequenceEncoder, Im2Seq, Im2Im
 | 
						|
from .RecMv1_enhance import MobileNetV1Enhance
 | 
						|
 | 
						|
from .RecCTCHead import CTCHead
 | 
						|
 | 
						|
# Registries mapping the ``type`` string of a config section to the class
# that implements it.
backbone_dict = {
    "MobileNetV1Enhance": MobileNetV1Enhance,
}
neck_dict = {
    'SequenceEncoder': SequenceEncoder,
    'Im2Seq': Im2Seq,
    'None': Im2Im,  # the key is the string 'None', not the None object
}
head_dict = {
    'CTCHead': CTCHead,
}
 | 
						|
 | 
						|
 | 
						|
class RecModel(nn.Module):
    """Model assembled from a backbone, a neck and a head, applied in that order.

    Each stage class is looked up by its ``type`` string in ``backbone_dict``,
    ``neck_dict`` and ``head_dict`` respectively.
    """

    def __init__(self, config):
        """Build the three stages described by ``config``.

        Args:
            config: Mapping-like object that supports ``'in_channels' in config``
                and attribute access to ``in_channels``, ``backbone``, ``neck``
                and ``head``. Each of the three sections is a mapping holding a
                ``'type'`` key; its remaining items are forwarded to the stage
                constructor as keyword arguments. ``config`` is not modified.

        Raises:
            AssertionError: If ``in_channels`` is missing or a section's type
                is not registered.
            KeyError: If a section has no ``'type'`` key.
        """
        super().__init__()
        assert 'in_channels' in config, 'in_channels must in model config'

        backbone_type, backbone_args = self._split_type(config.backbone)
        assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'
        self.backbone = backbone_dict[backbone_type](config.in_channels, **backbone_args)

        # Each stage is sized from the ``out_channels`` of the stage before it.
        neck_type, neck_args = self._split_type(config.neck)
        assert neck_type in neck_dict, f'neck.type must in {neck_dict}'
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **neck_args)

        head_type, head_args = self._split_type(config.head)
        assert head_type in head_dict, f'head.type must in {head_dict}'
        self.head = head_dict[head_type](self.neck.out_channels, **head_args)

        self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'

    @staticmethod
    def _split_type(section):
        """Return ``(type_name, kwargs)`` taken from a shallow copy of ``section``.

        Popping from a copy leaves the caller's config intact, so the same
        config can be used to build several models or be saved afterwards.
        """
        kwargs = dict(section)
        type_name = kwargs.pop('type')
        return type_name, kwargs

    def load_3rd_state_dict(self, _3rd_name, _state):
        """Forward ``_3rd_name`` and ``_state`` to the same-named method of each stage.

        The stages are called in the order backbone, neck, head.
        """
        self.backbone.load_3rd_state_dict(_3rd_name, _state)
        self.neck.load_3rd_state_dict(_3rd_name, _state)
        self.head.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        """Run ``x`` through backbone, neck and head and return the head's output."""
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x

    def encode(self, x):
        """Run ``x`` through backbone and neck, then through ``self.head.ctc_encoder``.

        The head must expose a ``ctc_encoder`` callable.
        NOTE(review): confirm that every class in ``head_dict`` provides it.
        """
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head.ctc_encoder(x)
        return x
 |