JAX深度學習
[英] 格裏戈裏·薩普諾夫(Grigory Sapunov)著 殷海英 譯
- 出版商: 清華大學
- 出版日期: 2026-03-01
- 售價: $768
- 語言: 簡體中文
- ISBN: 7302710473
- ISBN-13: 9787302710479
-
相關分類:
DeepLearning
- 此書翻譯自: Deep Learning with Jax (Paperback)
下單後立即進貨 (約4週~6週)
商品描述
作者簡介
目錄大綱
目 錄
第Ⅰ部分 入門基礎
第1章 何時以及為什麼使用JAX 3
1.1 使用JAX的理由 6
1.1.1 計算性能 7
1.1.2 函數式方法 9
1.1.3 JAX生態系統 10
1.2 JAX與NumPy的區別 11
1.2.1 JAX作為NumPy 12
1.2.2 可組合的變換 12
1.3 JAX與TensorFlow和PyTorch的區別 14
1.4 本章小結 16
第2章 你的第一個JAX程序 17
2.1 一個簡單的機器學習問題:手寫數字分類 18
2.2 JAX深度學習項目概覽 19
2.3 加載和準備數據集 20
2.4 在JAX中構建一個簡單的神經網絡 22
2.4.1 神經網絡初始化 24
2.4.2 神經網絡前向傳播 25
2.5 vmap:自動向量化計算以支持批處理 28
2.6 自動微分:如何在不手動計算導數的情況下計算梯度 30
2.6.1 損失函數 32
2.6.2 獲取梯度 33
2.6.3 梯度更新步驟 33
2.6.4 訓練循環 34
2.7 JIT:將代碼編譯為更快的版本 36
2.8 保存和部署模型 37
2.9 純函數和可組合的轉換:它們為什麼重要 39
2.10 練習 40
2.11 本章小結 40
第Ⅱ部分 JAX核心機制
第3章 數組操作 43
3.1 使用NumPy數組進行圖像處理 44
3.1.1 將圖像加載到NumPy數組中 45
3.1.2 對圖像執行基本預處理操作 48
3.1.3 向圖像添加噪聲 50
3.1.4 實現圖像濾波 51
3.1.5 將張量保存為圖像文件 55
3.2 JAX中的數組 57
3.2.1 切換到JAX類似NumPy的API 57
3.2.2 什麼是Array 58
3.2.3 與設備相關的操作 61
3.2.4 異步調度 65
3.2.5 在TPU上運行計算 66
3.3 與NumPy的區別 69
3.3.1 不可變性 70
3.3.2 類型 73
3.4 高級接口與低級接口:jax.numpy和jax.lax 76
3.4.1 控制流原語 77
3.4.2 類型提升 78
3.5 練習 79
3.6 本章小結 79
第4章 計算梯度 81
4.1 獲取導數的不同方法 82
4.1.1 手動求導 83
4.1.2 符號微分 84
4.1.3 數值微分 85
4.1.4 自動微分 87
4.2 使用自動微分計算梯度 89
4.2.1 在TensorFlow中使用梯度 91
4.2.2 在PyTorch中使用梯度 92
4.2.3 在JAX中使用梯度 92
4.2.4 高階導數 100
4.2.5 多變量情況 102
4.3 前向模式和反向模式自動微分 104
4.3.1 計算軌跡 105
4.3.2 前向模式和jvp() 106
4.3.3 反向模式和vjp() 110
4.3.4 深入探索 113
4.4 本章小結 114
第5章 編譯代碼 115
5.1 使用編譯 116
5.1.1 使用JIT編譯 117
5.1.2 純函數與編譯過程 123
5.2 JIT內部機制 125
5.2.1 Jaxpr:JAX程序的中間表示形式 125
5.2.2 XLA 134
5.2.3 使用AOT編譯 138
5.3 JIT的局限性 142
5.3.1 純函數與非純函數 142
5.3.2 精確數值 142
5.3.3 輸入參數值依賴的條件控制 142
5.3.4 編譯速度慢 142
5.3.5 類方法 144
5.3.6 簡單函數 145
5.4 練習 146
5.5 本章小結 146
第6章 向量化代碼 147
6.1 向量化函數的不同方法 148
6.1.1 樸素方法 149
6.1.2 手動向量化 151
6.1.3 自動向量化 151
6.1.4 性能比較 152
6.2 控制vmap()行為 154
6.2.1 控制映射的數組軸 154
6.2.2 控制輸出數組的軸 157
6.2.3 使用命名參數 158
6.2.4 使用裝飾器風格 160
6.2.5 使用集合操作 161
6.3 vmap()的實際應用案例 162
6.3.1 批數據處理 162
6.3.2 批量化神經網絡模型 164
6.3.3 每個樣本的梯度 165
6.3.4 向量化循環 166
6.4 本章小結 169
第7章 並行化計算 171
7.1 使用pmap()並行化計算 172
7.1.1 問題設置 172
7.1.2 像使用vmap一樣使用pmap 175
7.2 控制pmap()的行為 180
7.2.1 控制輸入和輸出的映射軸 181
7.2.2 使用命名軸和集合操作 186
7.3 數據並行的神經網絡訓練示例 193
7.3.1 準備數據和神經網絡結構 194
7.3.2 實現數據並行訓練 196
7.4 使用多主機配置 201
7.5 本章小結 206
第8章 使用張量切分 209
8.1 張量分片基礎 210
8.1.1 設備網格 212
8.1.2 位置分片 213
8.1.3 二維網格示例 213
8.1.4 使用復制 217
8.1.5 分片約束 219
8.1.6 命名切分 221
8.1.7 設備放置策略與錯誤 222
8.2 使用張量分片的多層感知機(MLP) 224
8.2.1 八路數據並行 224
8.2.2 四路數據並行,雙路張量並行 226
8.3 本章小結 229
第9章 JAX中的隨機數 231
9.1 生成隨機數據 232
9.1.1 載入數據集 233
9.1.2 生成隨機噪聲 235
9.1.3 執行隨機增強 239
9.2 與NumPy的區別 241
9.2.1 NumPy的工作原理 241
9.2.2 NumPy中的種子和狀態 243
9.2.3 JAX PRNG 246
9.2.4 JAX PRNG高級配置 252
9.3 在實際應用中生成隨機數 253
9.3.1 構建一個完整的數據增強管道 253
9.3.2 為神經網絡生成隨機初始化 255
9.4 本章小結 256
第10章 處理pytree 257
10.1 將復雜數據結構表示為pytree 258
10.2 處理pytree的函數 262
10.2.1 使用tree_map() 263
10.2.2 扁平化/還原pytree 265
10.2.3 使用tree_reduce() 267
10.2.4 轉置pytree 268
10.3 創建自定義pytree節點 271
10.4 本章小結 274
第Ⅲ部分 生態系統
第11章 高級神經網絡庫 277
11.1 使用MLP進行MNIST圖像分類 278
11.1.1 Flax中的MLP 278
11.1.2 Optax梯度變換庫 284
11.1.3 使用Flax訓練神經網絡 286
11.2 使用ResNet進行圖像分類 290
11.2.1 在Flax中管理狀態 290
11.2.2 使用Orbax保存和加載模型 296
11.3 使用Hugging Face生態系統 298
11.3.1 使用Hugging Face Model Hub中的預訓練模型 299
11.3.2 進一步探索:微調與再訓練 304
11.3.3 使用diffusers庫 306
11.4 本章小結 310
第12章 JAX生態系統的其他成員 313
12.1 深度學習生態系統 314
12.1.1 高層神經網絡庫 314
12.1.2 JAX中的大型語言模型(LLM) 315
12.1.3 工具庫 317
12.2 機器學習模塊 319
12.2.1 強化學習 319
12.2.2 其他機器學習庫 320
12.3 其他領域的JAX模塊 321
12.4 本章小結 322
附錄A 安裝JAX 325
附錄B 使用Google Colab 329
附錄C 使用Google Cloud TPU 331
附錄D 實驗性並行化 335



