opencompass/models/huggingface.py (29 changes: 18 additions & 11 deletions)
@@ -222,15 +222,15 @@ def _load_model(self,
 
     def generate(self,
                  inputs: List[str],
-                 max_out_len: int,
+                 max_out_len: Optional[int],
                  min_out_len: Optional[int] = None,
                  stopping_criteria: List[str] = [],
                  **kwargs) -> List[str]:
         """Generate results given a list of inputs.
 
         Args:
             inputs (List[str]): A list of strings.
-            max_out_len (int): The maximum length of the output.
+            max_out_len (Optional[int]): The maximum length of the output.
             min_out_len (Optional[int]): The minimum length of the output.
 
         Returns:
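
With `max_out_len` now typed `Optional[int]`, callers of the public `generate` API may omit the output budget entirely. A minimal usage sketch, assuming `model` is an already-constructed opencompass `HuggingFace` wrapper (construction details are not part of this diff):

    prompts = ['The capital of France is']

    # Explicit budget, as before this change.
    outputs = model.generate(prompts, max_out_len=50)

    # New: pass None and let the wrapper substitute its internal
    # default of 512 new tokens (see the _batch_generate hunk below).
    outputs = model.generate(prompts, max_out_len=None)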
@@ -255,15 +255,15 @@ def generate(self,
 
     def _batch_generate(self,
                         inputs: List[str],
-                        max_out_len: int,
+                        max_out_len: Optional[int],
                         min_out_len: Optional[int] = None,
                         stopping_criteria: List[str] = [],
                         **kwargs) -> List[str]:
         """Support for batch prompts inference.
 
         Args:
             inputs (List[str]): A list of strings.
-            max_out_len (int): The maximum length of the output.
+            max_out_len (Optional[int]): The maximum length of the output.
 
         Returns:
             List[str]: A list of generated strings.
@@ -315,8 +315,10 @@ def _batch_generate(self,
             kwargs['min_new_tokens'] = min_out_len
 
         # step-2: conduct model forward to generate output
+        # Handle max_out_len being None
+        effective_max_out_len = max_out_len if max_out_len is not None else 512
         outputs = self.model.generate(**tokens,
-                                      max_new_tokens=max_out_len,
+                                      max_new_tokens=effective_max_out_len,
                                       **kwargs)
 
         if not self.extract_pred_after_decode:
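
The explicit fallback matters because `transformers` would otherwise resolve a missing `max_new_tokens` from the model's generation config, whose stock `max_length` default is only 20 tokens. A standalone sketch of the idiom (the 512 default is this patch's choice, not a library constant):

    from typing import Optional

    DEFAULT_MAX_NEW_TOKENS = 512  # fallback chosen by this patch

    def resolve_max_new_tokens(max_out_len: Optional[int]) -> int:
        """Return an explicit token budget, substituting the default for None."""
        return max_out_len if max_out_len is not None else DEFAULT_MAX_NEW_TOKENS

    assert resolve_max_new_tokens(None) == 512
    assert resolve_max_new_tokens(100) == 100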
@@ -339,15 +341,15 @@ def _batch_generate(self,
 
     def _single_generate(self,
                          inputs: List[str],
-                         max_out_len: int,
+                         max_out_len: Optional[int],
                          min_out_len: Optional[int] = None,
                          stopping_criteria: List[str] = [],
                          **kwargs) -> List[str]:
         """Support for single prompt inference.
 
         Args:
             inputs (List[str]): A list of strings.
-            max_out_len (int): The maximum length of the output.
+            max_out_len (Optional[int]): The maximum length of the output.
 
         Returns:
             List[str]: A list of generated strings.
@@ -371,19 +373,22 @@ def _single_generate(self,
         if self.mode == 'mid':
             input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
             input_ids = torch.tensor(input_ids, device=self.model.device)
-            if len(input_ids[0]) > self.max_seq_len - max_out_len:
-                half = int((self.max_seq_len - max_out_len) / 2)
+            effective_max_out_len = (max_out_len if max_out_len is not None
+                                     else 0)
+            if len(input_ids[0]) > self.max_seq_len - effective_max_out_len:
+                half = int((self.max_seq_len - effective_max_out_len) / 2)
                 inputs = [
                     self.tokenizer.decode(input_ids[0][:half],
                                           skip_special_tokens=True) +
                     self.tokenizer.decode(input_ids[0][-half:],
                                           skip_special_tokens=True)
                 ]
 
+        effective_max_out_len = max_out_len if max_out_len is not None else 0
         input_ids = self.tokenizer(inputs,
                                    truncation=True,
                                    max_length=self.max_seq_len -
-                                   max_out_len)['input_ids']
+                                   effective_max_out_len)['input_ids']
         input_ids = torch.tensor(input_ids, device=self.model.device)
         origin_stopping_criteria = stopping_criteria
         if stopping_criteria:
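
In 'mid' mode the prompt is clipped symmetrically: when the input exceeds the window left after reserving the output budget, only the first and last `half` tokens survive. Note that `None` maps to 0 here (reserve nothing, keep maximal prompt), unlike the 512 used at generation time. A worked sketch of the arithmetic with illustrative values:

    max_seq_len = 2048
    prompt_len = 3000

    # max_out_len provided: reserve room for generation.
    effective = 100
    assert prompt_len > max_seq_len - effective  # triggers mid truncation
    half = int((max_seq_len - effective) / 2)    # keep 974 tokens per end

    # max_out_len is None: reserve nothing, keep as much prompt as possible.
    effective = 0
    half = int((max_seq_len - effective) / 2)    # keep 1024 tokens per end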
@@ -406,8 +411,10 @@ def _single_generate(self,
 
         # To accommodate the PeftModel, parameters should be passed in
         # key-value format for generate.
+        # Handle max_out_len being None
+        effective_max_out_len = max_out_len if max_out_len is not None else 512
         outputs = self.model.generate(input_ids=input_ids,
-                                      max_new_tokens=max_out_len,
+                                      max_new_tokens=effective_max_out_len,
                                       **kwargs)
 
         if not self.extract_pred_after_decode:
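
As the comment in this hunk notes, `peft.PeftModel.generate` forwards arguments to the wrapped base model by name, so passing everything as keywords keeps the call compatible with both plain and PEFT-wrapped models. A sketch of the call shape, assuming `model`, `tokenizer`, and `input_ids` are already prepared:

    # Keyword-only call: works for a bare transformers model and for a
    # peft.PeftModel wrapper alike (sketch; objects assumed loaded).
    outputs = model.generate(input_ids=input_ids,
                             max_new_tokens=512,
                             min_new_tokens=1)
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)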