41 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			41 lines
		
	
	
		
			1.4 KiB
		
	
	
	
		
			Python
		
	
	
	
import torch
 | 
						|
from transformers import PreTrainedModel
 | 
						|
 | 
						|
from ..utils import torch_gc
 | 
						|
 | 
						|
 | 
						|
class CPUTextEncoderWrapper(PreTrainedModel):
    """Run a text encoder on CPU in float32 while looking like the original.

    CPU inference does not support float16, so the wrapped encoder is moved to
    CPU and upcast to float32. Every output tensor is cast back to the caller's
    expected ``torch_dtype`` and moved back to the input's device, so callers
    (e.g. a diffusion pipeline keeping the rest of the model on GPU) see no
    difference in device or dtype.
    """

    def __init__(self, text_encoder, torch_dtype):
        """Wrap ``text_encoder``, moving it to CPU/float32.

        Args:
            text_encoder: the pretrained text encoder to wrap. It is moved to
                CPU; the original (possibly GPU) copy is released afterwards.
            torch_dtype: dtype callers expect outputs in (e.g. torch.float16).
        """
        super().__init__(text_encoder.config)
        self.config = text_encoder.config
        # Report the encoder's original device through the `device` property.
        self._device = text_encoder.device
        # CPU does not support float16, so run the encoder in float32.
        # NOTE: the original code passed non_blocking=True here; a GPU->CPU
        # copy into pageable memory with non_blocking=True may return before
        # the copy completes, and the float32 cast below reads the tensor
        # immediately — use a synchronous move instead.
        self.text_encoder = text_encoder.to(torch.device("cpu"))
        self.text_encoder = self.text_encoder.to(torch.float32)
        self.torch_dtype = torch_dtype
        # Drop the last reference to the original encoder and reclaim
        # GPU memory / caches via the project helper.
        del text_encoder
        torch_gc()

    def __call__(self, x, **kwargs):
        """Encode ``x`` on CPU, returning outputs on ``x``'s device and dtype.

        Args:
            x: input token-id tensor; its device determines where the
                outputs are moved back to.
            **kwargs: forwarded unchanged to the wrapped encoder.

        Returns:
            The encoder's output mapping with every tensor moved to ``x``'s
            device and cast to ``self.torch_dtype`` (tuple values become
            lists, as callers of this wrapper already expect).
        """
        input_device = x.device
        original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
        for k, v in original_output.items():
            if isinstance(v, tuple):
                # Cast/move each element; iterate directly instead of
                # indexing with range(len(v)).
                original_output[k] = [
                    t.to(input_device).to(self.torch_dtype) for t in v
                ]
            else:
                original_output[k] = v.to(input_device).to(self.torch_dtype)
        return original_output

    @property
    def dtype(self):
        # Advertise the caller-facing dtype, not the internal float32.
        return self.torch_dtype

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        # Advertise the encoder's original device, not the internal CPU.
        return self._device