1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
    **kwargs
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Run the language-model forward pass, splicing vision features into
    the token embeddings wherever image placeholder tokens appear.

    Embeds ``input_ids``, then — when a vision tower is attached and
    ``images`` are provided — replaces the span between each
    ``im_start_token`` and ``im_end_token`` with the corresponding image
    features from ``get_vision_embedding``, and finally delegates to the
    parent class's ``forward`` with ``inputs_embeds`` instead of
    ``input_ids``.

    Args:
        input_ids: Token ids, shape (batch, seq). Required unless
            ``inputs_embeds`` is given (the vision branch still indexes
            ``input_ids``, so callers passing only ``inputs_embeds`` must
            not pass ``images`` — TODO confirm against callers).
        attention_mask / past_key_values / use_cache / output_attentions /
            output_hidden_states / return_dict: forwarded unchanged to the
            parent ``forward``.
        inputs_embeds: Precomputed token embeddings; computed from
            ``input_ids`` when absent and no KV cache is present.
        images: Image batch (tensor) or list of per-sample image tensors
            for the vision tower.

    Returns:
        Whatever the parent model's ``forward`` returns
        (tuple or ``BaseModelOutputWithPast``).

    Raises:
        ValueError: if image start/end tokens are unbalanced, or an end
            token is not found ``num_patches`` positions after a start.
        NotImplementedError: if ``vision_config.use_im_start_end`` is False.
    """
    # When set, the original (pre-finetuning) embedding table is being kept
    # frozen: the splice below then detach()es the token embeddings outside
    # the <im_start>/<im_end> markers so only those markers receive grads.
    orig_embeds_params = getattr(self, 'orig_embeds_params', None)
    if inputs_embeds is None and past_key_values is None:
        inputs_embeds = self.embed_tokens(input_ids)

    vision_tower = getattr(self, 'vision_tower', None)
    # Skip vision splicing during single-token decode steps
    # (seq len == 1, i.e. cached generation) unless training.
    if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
        if type(images) is list:
            # Variable-sized images: encode one at a time.
            image_features = []
            for image in images:
                image_forward_out = self.get_vision_embedding(image.unsqueeze(0))[0]
                image_features.append(image_forward_out)
        else:
            image_features = self.get_vision_embedding(images)

        # Zero placeholder with the expected (num_query, hidden) shape,
        # used only to add a 0-valued term for image-free sequences.
        dummy_image_features = torch.zeros(
            self.config.num_query,
            self.config.hidden_size,
            device=inputs_embeds.device,
            dtype=inputs_embeds.dtype)

        new_input_embeds = []
        cur_image_idx = 0  # index into image_features, advanced per spliced image
        for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
            if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
                # No image in this sequence: add an all-zero term so the
                # graph still touches the dummy features — presumably to
                # keep gradient/DDP bookkeeping consistent; TODO confirm.
                cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
                new_input_embeds.append(cur_input_embeds)
                continue

            if self.vision_config.use_im_start_end:
                cur_image_features = image_features[cur_image_idx]
                num_patches = cur_image_features.shape[0]
                if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
                    raise ValueError(
                        "The number of image start tokens and image end tokens should be the same.")
                image_start_tokens = torch.where(
                    cur_input_ids == self.vision_config.im_start_token)[0]
                for image_start_token_pos in image_start_tokens:
                    cur_image_features = image_features[cur_image_idx].to(
                        device=cur_input_embeds.device)
                    num_patches = cur_image_features.shape[0]
                    # Layout check: <im_start> <patch>*num_patches <im_end>.
                    if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
                        raise ValueError(
                            "The image end token should follow the image start token.")
                    if orig_embeds_params is not None:
                        # Frozen-embedding mode: keep grads only on the
                        # <im_start>/<im_end> marker embeddings; detach the
                        # text on either side of the spliced span.
                        cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
                    else:
                        # Replace the patch-token span with the image features.
                        cur_new_input_embeds = torch.cat(
                            (cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
                    cur_image_idx += 1
                # NOTE(review): only the splice from the LAST start token
                # survives here — each iteration rebuilds from the original
                # cur_input_embeds. Matches upstream behavior for the
                # single-image case; multi-image sequences look unsupported.
                new_input_embeds.append(cur_new_input_embeds)
            else:
                raise NotImplementedError
        # All spliced sequences keep the original length, so stacking is valid.
        inputs_embeds = torch.stack(new_input_embeds, dim=0)
        input_ids = None  # parent forward must consume inputs_embeds now

    return super(OmniLMMModel, self).forward(
        input_ids=input_ids, attention_mask=attention_mask,
        past_key_values=past_key_values, inputs_embeds=inputs_embeds,
        use_cache=use_cache, output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict, **kwargs
    )
|