> ## Documentation Index
> Fetch the complete documentation index at: https://wb-21fd5541-sdk-testing-latest.mintlify.site/llms.txt
> Use this file to discover all available pages before exploring further.

# PyTorch torchtune

> PyTorch torchtuneでW&B loggingを使用して、WandBLoggerメトリクスロガーでLLMのファインチューニング実験をトラッキングします。

export const ColabLink = ({url}) => <a href={url} target="_blank" rel="noopener noreferrer" className="colab-link">
    <svg width="20" height="20" viewBox="0 0 24 24" fill="currentColor" xmlns="http://www.w3.org/2000/svg">
      <path d="M14.25.18l.9.2.73.26.59.3.45.32.34.34.25.34.16.33.1.3.04.26.02.2-.01.13V8.5l-.05.63-.13.55-.21.46-.26.38-.3.31-.33.25-.35.19-.35.14-.33.1-.3.07-.26.04-.21.02H8.77l-.69.05-.59.14-.5.22-.41.27-.33.32-.27.35-.2.36-.15.37-.1.35-.07.32-.04.27-.02.21v3.06H3.17l-.21-.03-.28-.07-.32-.12-.35-.18-.36-.26-.36-.36-.35-.46-.32-.59-.28-.73-.21-.88-.14-1.05-.05-1.23.06-1.22.16-1.04.24-.87.32-.71.36-.57.4-.44.42-.33.42-.24.4-.16.36-.1.32-.05.24-.01h.16l.06.01h8.16v-.83H6.18l-.01-2.75-.02-.37.05-.34.11-.31.17-.28.25-.26.31-.23.38-.2.44-.18.51-.15.58-.12.64-.1.71-.06.77-.04.84-.02 1.27.05zm-6.3 1.98l-.23.33-.08.41.08.41.23.34.33.22.41.09.41-.09.33-.22.23-.34.08-.41-.08-.41-.23-.33-.33-.22-.41-.09-.41.09zm13.09 3.95l.28.06.32.12.35.18.36.27.36.35.35.47.32.59.28.73.21.88.14 1.04.05 1.23-.06 1.23-.16 1.04-.24.86-.32.71-.36.57-.4.45-.42.33-.42.24-.4.16-.36.09-.32.05-.24.02-.16-.01h-8.22v.82h5.84l.01 2.76.02.36-.05.34-.11.31-.17.29-.25.25-.31.24-.38.2-.44.17-.51.15-.58.13-.64.09-.71.07-.77.04-.84.01-1.27-.04-1.07-.14-.9-.2-.73-.25-.59-.3-.45-.33-.34-.34-.25-.34-.16-.33-.1-.3-.04-.25-.02-.2.01-.13v-5.34l.05-.64.13-.54.21-.46.26-.38.3-.32.33-.24.35-.2.35-.14.33-.1.3-.06.26-.04.21-.02.13-.01h5.84l.69-.05.59-.14.5-.21.41-.28.33-.32.27-.35.2-.36.15-.36.1-.35.07-.32.04-.28.02-.21V6.07h2.09l.14.01.21.03zm-6.47 14.25l-.23.33-.08.41.08.41.23.33.33.23.41.08.41-.08.33-.23.23-.33.08-.41-.08-.41-.23-.33-.33-.23-.41-.08-.41.08z" />
    </svg>
    Colabで試す
  </a>;

<ColabLink url="https://colab.research.google.com/github/wandb/examples/blob/master/colabs/torchtune/torchtune_and_wandb.ipynb" />

[torchtune](https://meta-pytorch.org/torchtune/stable/index.html) は、大規模言語モデル (LLM) の作成、ファインチューニング、実験を効率化するために設計された、PyTorch ベースのライブラリです。さらに、torchtune は [W\&B へのログ記録](https://meta-pytorch.org/torchtune/stable/deep_dives/wandb_logging.html) を標準でサポートしており、トレーニング過程のトラッキングと可視化を強化します。

<Frame>
  <img src="https://mintcdn.com/wb-21fd5541-sdk-testing-latest/5BwwFpNAnQO_33rW/images/integrations/torchtune_dashboard.png?fit=max&auto=format&n=5BwwFpNAnQO_33rW&q=85&s=bec3b54a93c202869782109571187c8a" alt="TorchTuneのトレーニングダッシュボード" width="1942" height="1286" data-path="images/integrations/torchtune_dashboard.png" />
</Frame>

[torchtune を使った Mistral 7B のファインチューニング](https://wandb.ai/capecape/torchtune-mistral/reports/torchtune-The-new-PyTorch-LLM-fine-tuning-library---Vmlldzo3NTUwNjM0) に関する W\&B のブログ記事をご覧ください。

<div id="wb-logging-at-your-fingertips">
  ## すぐ使える W\&B logging
</div>

<Tabs>
  <Tab title="コマンドライン">
    起動時にコマンドライン引数を上書きします。

    ```bash theme={null}
    tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
      metric_logger._component_=torchtune.utils.metric_logging.WandBLogger \
      metric_logger.project="llama3_lora" \
      log_every_n_steps=5
    ```
  </Tab>

  <Tab title="レシピ">
    レシピの設定で W\&B logging を有効にします。

    ```yaml theme={null}
    # llama3/8B_lora_single_device.yaml 内
    metric_logger:
      _component_: torchtune.utils.metric_logging.WandBLogger
      project: llama3_lora
    log_every_n_steps: 5
    ```
  </Tab>
</Tabs>

<div id="use-the-wb-metric-logger">
  ## W\&Bメトリクスロガーを使用する
</div>

レシピの設定ファイル内の `metric_logger` セクションを変更して、W\&B logging を有効にします。`_component_` を `torchtune.utils.metric_logging.WandBLogger` クラスに変更してください。`project` 名や `log_every_n_steps` を渡して、logging の動作をカスタマイズすることもできます。

また、[wandb.init()](/ja/models/ref/python/functions/init) method に渡すのと同様に、そのほかの `kwargs` も渡せます。たとえば、チームで作業している場合は、`entity` 引数を `WandBLogger` クラスに渡してチーム名を指定できます。

<Tabs>
  <Tab title="レシピ">
    ```yaml theme={null}
    # llama3/8B_lora_single_device.yaml 内
    metric_logger:
      _component_: torchtune.utils.metric_logging.WandBLogger
      project: llama3_lora
      entity: my_project
      job_type: lora_finetune_single_device
      group: my_awesome_experiments
    log_every_n_steps: 5
    ```
  </Tab>

  <Tab title="コマンドライン">
    ```shell theme={null}
    tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
      metric_logger._component_=torchtune.utils.metric_logging.WandBLogger \
      metric_logger.project="llama3_lora" \
      metric_logger.entity="my_project" \
      metric_logger.job_type="lora_finetune_single_device" \
      metric_logger.group="my_awesome_experiments" \
      log_every_n_steps=5
    ```
  </Tab>
</Tabs>

<div id="what-is-logged">
  ## 何がログされますか？
</div>

ログされたメトリクスは、W\&B ダッシュボードで確認できます。デフォルトでは、W\&B は設定ファイル内のすべてのハイパーパラメーターと Launch のオーバーライドをログします。

W\&B は、解決済みの設定を **Overview** タブに記録します。また、その設定は YAML 形式で [Files tab](https://wandb.ai/capecape/torchtune/runs/joyknwwa/files) にも保存されます。

<Frame>
  <img src="https://mintcdn.com/wb-21fd5541-sdk-testing-latest/5BwwFpNAnQO_33rW/images/integrations/torchtune_config.png?fit=max&auto=format&n=5BwwFpNAnQO_33rW&q=85&s=5a57c919a541c370e88af0fea41bea4b" alt="TorchTune の設定" width="1806" height="1362" data-path="images/integrations/torchtune_config.png" />
</Frame>

<div id="logged-metrics">
  ### ログされたメトリクス
</div>

各レシピには、それぞれ独自のトレーニングループがあります。どのメトリクスがログされるかは各レシピごとに異なりますが、デフォルトでは次のメトリクスが含まれます。

| Metric              | Description                                                                                                                  |
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------- |
| `loss`              | モデルの損失                                                                                                                       |
| `lr`                | 学習率                                                                                                                          |
| `tokens_per_second` | モデルの1秒あたりのトークン数                                                                                                              |
| `grad_norm`         | モデルの勾配ノルム                                                                                                                    |
| `global_step`       | トレーニングループ内の現在のstepに対応します。勾配累積が考慮されるため、基本的にはoptimizerのstepが実行されるたびに更新されます。つまり、モデルは `gradient_accumulation_steps` ごとに1回更新されます。 |

<Note>
  `global_step` はトレーニングstep数そのものではありません。これはトレーニングループ内の現在のstepに対応します。勾配累積が考慮されるため、基本的にはoptimizerのstepが実行されるたびに `global_step` は1増加します。たとえば、dataloaderに10個のバッチがあり、gradient accumulation stepsが2で、3エポック実行する場合、optimizerは15回stepを実行します。この場合、`global_step` は1から15までの値を取ります。
</Note>

torchtuneのシンプルな設計により、custom metricsを簡単に追加したり、既存のメトリクスを変更したりできます。対応する [レシピファイル](https://github.com/meta-pytorch/torchtune/tree/main/recipes) を修正するだけで十分です。たとえば、`current_epoch` を総エポック数に対する割合として計算し、次のようにログできます。

```python theme={null}
# レシピファイル内の `train.py` の関数内
self._metric_logger.log_dict(
    {"current_epoch": self.epochs * self.global_step / self._steps_per_epoch},
    step=self.global_step,
)
```

<Note>
  このライブラリは急速に進化しており、現在のメトリクスは変更される可能性があります。カスタムメトリクスを追加する場合は、レシピを修正し、対応する `self._metric_logger.*` 関数を呼び出してください。
</Note>

<div id="save-and-load-checkpoints">
  ## チェックポイントの保存と読み込み
</div>

torchtune ライブラリは、さまざまな[チェックポイント形式](https://meta-pytorch.org/torchtune/stable/deep_dives/checkpointer.html)をサポートしています。使用しているモデルの取得元に応じて、適切な[checkpointer クラス](https://meta-pytorch.org/torchtune/stable/deep_dives/checkpointer.html)に切り替える必要があります。

モデル チェックポイントを[W\&B Artifacts](/ja/models/artifacts/)に保存したい場合、最も簡単な方法は、対応するレシピ内の `save_checkpoint` 関数をオーバーライドすることです。

以下は、`save_checkpoint` 関数をオーバーライドして、モデル チェックポイントを W\&B Artifacts に保存する方法の例です。

```python theme={null}
def save_checkpoint(self, epoch: int) -> None:
    ...
    ## チェックポイントをW&Bに保存する
    ## Checkpointerクラスによってファイル名が異なる
    ## full_finetuneの場合の例
    checkpoint_file = Path.joinpath(
        self._checkpointer._output_dir, f"torchtune_model_{epoch}"
    ).with_suffix(".pt")
    wandb_artifact = wandb.Artifact(
        name=f"torchtune_model_{epoch}",
        type="model",
        # モデル チェックポイントの説明
        description="Model checkpoint",
        # dictとして任意のメタデータを追加できる
        metadata={
            utils.SEED_KEY: self.seed,
            utils.EPOCHS_KEY: self.epochs_run,
            utils.TOTAL_EPOCHS_KEY: self.total_epochs,
            utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
        },
    )
    wandb_artifact.add_file(checkpoint_file)
    wandb.log_artifact(wandb_artifact)
```
