| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from timm import create_model |
| from transformers import ( |
| AutoConfig, |
| AutoModel, |
| AutoTokenizer, |
| PretrainedConfig, |
| PreTrainedModel, |
| ) |
| from transformers.utils import ModelOutput |
|
|
| from .location_encoder import LocationEncoder |
|
|
|
|
class CLOSPConfig(PretrainedConfig):
    """
    Hyper-parameters used to build a ``CLOSPModel``.

    Args:
        vision_model_key: Model name handed to ``timm.create_model`` for both
            the 2-channel (S1) and the 13-channel (S2) image encoders.
        s1_embedding_dim: Input width of the linear projection applied to the
            S1 encoder output.
        s2_embedding_dim: Input width of the linear projection applied to the
            S2 encoder output.
        s1_head_dim: Forwarded as ``num_classes`` when creating the S1 encoder.
        s2_head_dim: Forwarded as ``num_classes`` when creating the S2 encoder.
        text_model_name_or_path: Identifier used to load the text model
            configuration and the tokenizer.
        use_location_encoder: Whether the model builds a location encoder (and
            therefore requires coordinates in ``forward``).
        location_embedding_dim: Input width of the linear projection applied to
            the location encoder output.
        projection_dim: Output width of the image and location projections.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "closp"

    def __init__(
        self,
        vision_model_key: str = "vit-s",
        s1_embedding_dim: int = 384,
        s2_embedding_dim: int = 384,
        s1_head_dim: int = 0,
        s2_head_dim: int = 0,
        text_model_name_or_path: str = "distilbert-base-uncased",
        use_location_encoder: bool = True,
        location_embedding_dim: int = 512,
        projection_dim: int = 768,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Image encoders (shared backbone key, per-sensor widths).
        self.vision_model_key = vision_model_key
        self.s1_embedding_dim, self.s1_head_dim = s1_embedding_dim, s1_head_dim
        self.s2_embedding_dim, self.s2_head_dim = s2_embedding_dim, s2_head_dim

        # Text encoder.
        self.text_model_name_or_path = text_model_name_or_path

        # Optional location encoder.
        self.use_location_encoder = use_location_encoder
        self.location_embedding_dim = location_embedding_dim

        # Width of the shared embedding space for image/location projections.
        self.projection_dim = projection_dim
|
|
|
|
| |
@dataclass
class CLOSPOutput(ModelOutput):
    """
    Container returned by ``CLOSPModel.forward``.

    Every field defaults to ``None``; the location-related fields stay ``None``
    when the model's location encoder is disabled, and ``loss`` stays ``None``
    unless ``forward`` is called with ``return_loss=True``.
    """

    # Mean of the cross-entropy losses over all non-None logit matrices.
    loss: torch.FloatTensor = None
    # ``image_embeds @ text_embeds.T``.
    logits_per_image: torch.FloatTensor = None
    # Transpose of ``logits_per_image``.
    logits_per_text: torch.FloatTensor = None
    # ``location_embeds @ image_embeds.T``.
    logits_per_loc_img: torch.FloatTensor = None
    # ``image_embeds @ location_embeds.T``.
    logits_per_img_loc: torch.FloatTensor = None
    # L2-normalized projected image features.
    image_embeds: torch.FloatTensor = None
    # L2-normalized first-token hidden state of the text model.
    text_embeds: torch.FloatTensor = None
    # L2-normalized projected location features.
    location_embeds: torch.FloatTensor = None
|
|
|
|
class CLOSPModel(PreTrainedModel):
    """
    Contrastive model relating image, text and (optionally) location embeddings.

    Images are encoded by one of two timm backbones selected by channel count
    (2-channel "S1" or 13-channel "S2") and linearly projected; text is encoded
    by a Hugging Face text model whose first-token hidden state is used as the
    embedding; coordinates are encoded by ``LocationEncoder`` and linearly
    projected. All embeddings are L2-normalized before similarity logits are
    computed as plain dot products.
    """

    config_class = CLOSPConfig

    def __init__(self, config: CLOSPConfig):
        """Build the encoders and projection layers described by ``config``."""
        super().__init__(config)

        # Image encoders: same backbone key, different input channel counts.
        self.s1_encoder = create_model(
            config.vision_model_key,
            in_chans=2,
            num_classes=config.s1_head_dim,
            pretrained=False,
        )
        self.s2_encoder = create_model(
            config.vision_model_key,
            in_chans=13,
            num_classes=config.s2_head_dim,
            pretrained=False,
        )
        self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
        self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)

        # Text encoder is instantiated from its configuration only (no
        # pretrained weights are loaded here); the tokenizer is loaded from
        # the same identifier.
        self.text_model = AutoModel.from_config(
            AutoConfig.from_pretrained(config.text_model_name_or_path)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)

        # Optional location branch.
        if config.use_location_encoder:
            # NOTE(review): the LocationEncoder arguments are hard-coded; the
            # leading 512 presumably has to agree with
            # config.location_embedding_dim (the input width of the projection
            # below) - confirm against LocationEncoder's signature.
            self.location_encoder = LocationEncoder(512, 2, 256, 10)
            self.location_projection = nn.Linear(
                config.location_embedding_dim, config.projection_dim
            )

    def tokenize_text(self, text: str):
        """
        Tokenize ``text`` with the model's tokenizer.

        The output is padded to, and truncated at, the tokenizer's
        ``model_max_length`` and returned as PyTorch tensors.
        """
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """
        Encode an image batch into L2-normalized embeddings.

        The encoder is chosen from ``image.shape[1]``: exactly 2 channels goes
        through the S1 encoder/projection, anything else through the S2
        encoder/projection. The input is cast to float32 first.
        """
        image = image.float()
        if image.shape[1] == 2:
            image_features = self.s1_projection(self.s1_encoder(image))
        else:
            image_features = self.s2_projection(self.s2_encoder(image))

        return F.normalize(image_features, p=2, dim=-1)

    def get_text_features(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode tokenized text into L2-normalized embeddings.

        The embedding is the hidden state of the first token of the text
        model's last layer.
        """
        # Only `last_hidden_state` is consumed, so the per-layer hidden states
        # are not requested; asking for them would just hold extra activations.
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        text_features = text_outputs.last_hidden_state[:, 0, :]
        return F.normalize(text_features, p=2, dim=-1)

    def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Encode coordinates into L2-normalized embeddings.

        Raises:
            ValueError: If the model was configured without a location encoder.
        """
        if not self.config.use_location_encoder:
            raise ValueError(
                "Location encoder is not enabled for this model. Set `use_location_encoder=True` in config."
            )
        location_features = self.location_encoder(coords)
        location_features = self.location_projection(location_features)
        return F.normalize(location_features, p=2, dim=-1)

    def forward(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        coords: torch.Tensor = None,
        return_loss: bool = False,
    ) -> CLOSPOutput:
        """
        Compute embeddings, pairwise similarity logits and optionally the loss.

        Args:
            image: Image batch; channel count selects the encoder.
            input_ids: Token ids for the text model.
            attention_mask: Attention mask for the text model.
            coords: Coordinates for the location encoder. Required when the
                model is configured with ``use_location_encoder=True``.
            return_loss: When True, also compute the mean cross-entropy over
                every available logit matrix, using the batch index as target.

        Returns:
            A ``CLOSPOutput``; location-related fields are ``None`` when the
            location encoder is disabled and ``loss`` is ``None`` unless
            ``return_loss`` is True.

        Raises:
            ValueError: If the location encoder is enabled and ``coords`` is
                None.
        """
        image_embeds = self.get_image_features(image)
        text_embeds = self.get_text_features(input_ids, attention_mask)

        # Image <-> text similarities.
        logits_per_image = image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T

        # Image <-> location similarities (only with the location branch).
        location_embeds = None
        logits_per_loc_img = None
        logits_per_img_loc = None

        if self.config.use_location_encoder:
            if coords is None:
                raise ValueError(
                    "Coordinates must be provided when use_location_encoder is True."
                )
            location_embeds = self.get_location_features(coords)
            logits_per_loc_img = location_embeds @ image_embeds.T
            logits_per_img_loc = image_embeds @ location_embeds.T

        loss = None
        if return_loss:
            candidate_logits = (
                logits_per_image,
                logits_per_text,
                logits_per_loc_img,
                logits_per_img_loc,
            )
            # Matching pairs sit on the diagonal, so the target class of row i
            # is i. Build the targets directly on the logits' device instead of
            # allocating on CPU and copying.
            targets = torch.arange(
                logits_per_image.size(0), device=logits_per_image.device
            )
            losses = [
                F.cross_entropy(logits, targets)
                for logits in candidate_logits
                if logits is not None
            ]
            loss = sum(losses) / len(losses)

        return CLOSPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_loc_img=logits_per_loc_img,
            logits_per_img_loc=logits_per_img_loc,
            image_embeds=image_embeds,
            text_embeds=text_embeds,
            location_embeds=location_embeds,
        )
|
|