| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- (base) [root@localhost ~]# docker exec finetune-trainer bash -c 'sed -n "1400,1450p" /opt/conda/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py'
- warnings.warn(
- "compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
- "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
- )
- compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
- with compute_loss_context_manager():
- loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
- # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
- loss = loss.to(self.args.device)
- # force log the metrics
- self.store_metrics(metrics, train_eval="train")
- if return_outputs:
- return (loss, metrics)
- return loss
- def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
- """Generate samples from the model and reference model for the given batch of inputs."""
- # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
- # the torch cuda amp context manager as some hidden states are silently casted to full precision.
- generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
- with generate_context_manager():
- policy_output = model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.tokenizer.pad_token_id,
- )
- # if reference_output in batch use that otherwise use the reference model
- if "reference_output" in batch:
- reference_output = batch["reference_output"]
- else:
- if self.ref_model is None:
- with self.null_ref_context():
- reference_output = self.model.generate(
- input_ids=batch["prompt_input_ids"],
- attention_mask=batch["prompt_attention_mask"],
- max_length=self.max_length,
- do_sample=True,
- pad_token_id=self.tokenizer.pad_token_id,
- )
- else:
- reference_output = self.ref_model.generate(
- input_ids=batch["prompt_input_ids"],
|