509 Views
November 20, 25
スライド概要
テンソルの形状とデータ型に対しても型チェックを行えるjaxtypingというpythonライブラリを紹介しています。
Computer Vision Engineer
jaxtypingの紹介 2025.10.02 小林 茂樹 GOドライブ株式会社 AI Community
01|モチベーション 項目 02|型アノテーションの書き方 03|実行時型チェック 04|まとめと所感 2
01 モチベーション 3
shapeが分からない辛さ ▪ NumpyやPyTorchでテンソルを扱う際に次元(shape)が コロコロ変わる ▪ コードを読んだだけではshapeが分からず、具体的な処理 が掴めないことがある ▪ 変数名や処理の背景からshapeを推測しながらコードを読 むことが多いかと思います ▪ 例:torch.tensorの画像なら (B C H W) だろうと推測する shapeの推測には限界がある 4
shapeの推測が困難な例 1. poseという変数がある 2. 位置姿勢があると推測 3. 姿勢の表現方法によってshapeが違う a. オイラーのshape:(3,) b. クォータ二オンのshape:(4,) c. 回転行列のshape:(3,3) 変数名だけではshapeの推測が困難 5
shapeが分からない辛さ 以下のようなshapeに対してのコメントやアサーションを書いたことは ありませんか? 6
jaxtypingのススメ shapeに対してコメントを書いても良いが、ちゃんと型 チェックしたい jaxtypingを使いましょう! 7
jaxtyping shape, dtypeに対しても型アノテーションを行うライブラリ ▪ リポジトリ:https://github.com/patrick-kidger/jaxtyping ▪ ドキュメント:https://docs.kidger.site/jaxtyping/ ▪ 名前的にJAXしか対応していなさそうだが、 PyTorch/NumPy/TensorFlowにも対応 ▪ 実行時型チェックもできる 8
02 型アノテーションの書き方 9
例 https://docs.kidger.site/jaxtyping/#example 10
基本形 基本は以下のような型になる dtype[array, shape] 例えばPyTorchで画像を扱う場合 11
array 基本は以下のような型になる dtype[array, shape] 例えばPyTorchで画像を扱う場合 12
array 複数のライブラリのarrayを使用することが可能 ▪ ▪ ▪ ▪ ▪ jax.Array np.ndarray torch.Tensor tf.Tensor mx.array https://docs.kidger.site/jaxtyping/api/array/#array 13
shape 基本は以下のような型になる dtype[array, shape] 例えばPyTorchで画像を扱う場合 14
shapeの書き方 str型でshapeを指定した場合、変数となる これらの変数は関数内で一貫した値にならなければエラーになる int型で指定することで定数であることを明示できる 3チャネルの224x224の画像の場合 https://docs.kidger.site/jaxtyping/api/array/#shape 15
shapeの書き方 ローカル・インスタンス変数などの値を次元数として使う こともできる インスタンス変数の値を 次元数として設定 https://docs.kidger.site/jaxtyping/api/array/#shape 16
shapeの書き方 変数を用いた計算からshapeを設定できる dimから1引いた次元数 をshapeに指定 https://docs.kidger.site/jaxtyping/api/array/#shape 17
shapeの書き方 アンダースコアを設定すれば該当の次元は型チェックを避け られる チャネルの次元は 型チェックされない https://docs.kidger.site/jaxtyping/api/array/#shape 18
shapeの書き方 =を使うとコメントとなる(型チェックはされない) shapeに対して型チェック は行われない https://docs.kidger.site/jaxtyping/api/array/#shape 19
shapeの書き方 *を付けた軸は連続する0個以上の軸にできる shapeが (N B C H W) の場合、以下のように書ける https://docs.kidger.site/jaxtyping/api/array/#shape 20
shapeの書き方 ... を0個以上の軸とすることができる shapeが (N B C H W) の場合、以下のように書ける https://docs.kidger.site/jaxtyping/api/array/#shape 21
dtype 基本は以下のような型になる dtype[array, shape] 例えばPyTorchで画像を扱う場合 22
dtype 基本的なデータ型が用意されている ▪ ▪ ▪ ▪ Bool Float Int UInt https://docs.kidger.site/jaxtyping/api/array/#dtype 23
dtype 特定の精度のデータ型もある ▪ Float16 ▪ Float32 ▪ Float64 ▪ BFloat16 ▪ ▪ ▪ ▪ ▪ ▪ Int2 Int4 Int8 Int16 Int32 Int64 ▪ ▪ ▪ ▪ ▪ ▪ UInt2 UInt4 UInt8 UInt16 UInt32 UInt64 https://docs.kidger.site/jaxtyping/api/array/#dtype 24
03 実行時型チェック 25
hookの種類 実行時型チェックを呼び出すhook3種類 ▪ 関数単位でのhook ▪ モジュール単位でのhook ▪ pytestに対してのhook https://docs.kidger.site/jaxtyping/api/runtime-type-checking/ 26
関数単位でのhook 実行時型チェックを行いたい関数に対してjaxtypedというデコレータを設定する データクラスの場合は__init__の中に対して実行時型チェックが行われる https://docs.kidger.site/jaxtyping/api/runtime-type-checking/#jaxtyping.jaxtyped 27
モジュール単位でのhook install_import_hookのwithブロック内でimportする このブロック内でimportされたモジュールの全関数に対して実行時型チェックが 適用される https://docs.kidger.site/jaxtyping/api/runtime-type-checking/#jaxtyping.install_import_hook 28
pytestに対してのhook pytestに追加で引数を与えれば良い jaxtyping-packagesに実行時型チェックを適用したいモジュールを指定する 最もキレイにまとまるが、テストに対してしか実行時型チェックができない pyproject.tomlで設定する場合 pytest.iniで設定する場合 https://docs.kidger.site/jaxtyping/api/runtime-type-checking/#pytest-hook 29
(余談)pytorch-lightningを用いたテスト pytorch-lightningを用いている場合はfast_dev_runというパラメーターを設定 することでtrainingとvalidationを1度だけ実行することができる ➡ ユニットテストでtraining、validationに対してのjaxtypingの型チェックを少 ない計算量で行うことができる https://lightning.ai/docs/pytorch/stable/common/trainer.html#fast-dev-run 30
04 まとめと所感 31
改善されると嬉しいところ IDEのホバー情報にshape、dtypeに対しての型情報が出てこない shape、dtypeについて の情報がない 32
改善されると嬉しいところ ログが長い ▪ テンソルの内容が全てログに出てきて本質的なログが分かりづらい エラーログ サンプルコード 想定している返り値の shapeではないので エラーになる エラー毎にテンソルの 生値が出てくるので、 エラーが冗長になる 33
まとめと所感 NumpyやPyTorchなどのarrayに対してshapeとdtypeも含めて型アノテーション と型チェックができるライブラリ ▪ jaxtypingを導入したリポジトリを約1年ぶりに触ったが、可読性が上がって おり理解を早めた(と思う) ▪ コードを見ただけでshapeが分かるのはかなり嬉しい ▪ もっと流行ってほしい 34