JAX深度學習

[英] 格裏戈裏·薩普諾夫(Grigory Sapunov)著 殷海英 譯

  • JAX深度學習-preview-1
  • JAX深度學習-preview-2
  • JAX深度學習-preview-3
JAX深度學習-preview-1

商品描述

"谷歌的JAX為深度學習提供了全新的技術視角。這款功能強大的庫能讓開發者精確控制梯度計算等底層流程,在大規模數據集上實現快速、高效的模型訓練與推理。JAX徹底改變了科研人員開展深度學習的思路與方式。如今,JAX已擁有強大的工具和庫生態系統,讓各類應用程序都能輕松應對進化計算、聯邦學習以及其他對性能要求極高的任務。 《JAX深度學習》將教你如何利用JAX構建高效的神經網絡。在這本內容豐富的佳作中,你將了解JAX的獨特功能,攻克深度學習中的關鍵性能挑戰,例如,如何在TPU集群上進行分布式計算;通過創建圖像分類工具、圖像過濾應用等實際項目,讓你親身實踐這個庫的用法。書中的代碼示例經過精心註解,展示了JAX的函數式編程思維如何提升代碼的可組合性和並行化能力。 核心亮點: ? 運用JAX進行數值計算 ? 基於JAX原語構建可微分模型 ? 使用JAX進行分布式和並行計算 ? 使用高級神經網絡庫,如Flax"

作者簡介

格裏戈裏?薩普諾夫(Grigory Sapunov)是Intento公司的聯合創始人兼首席技術官(CTO)。他是一位擁有超過20年從業經驗的軟件工程師,獲得人工智能博士學位,同時持有機器學習領域谷歌開發者專家(Google Developer Expert,GDE)認證。

目錄大綱

目    錄

第Ⅰ部分  入門基礎

第1章  何時以及為什麼使用JAX  3

1.1  使用JAX的理由  6

1.1.1  計算性能  7

1.1.2  函數式方法  9

1.1.3  JAX生態系統  10

1.2  JAX與NumPy的區別  11

1.2.1  JAX作為NumPy  12

1.2.2  可組合的變換  12

1.3  JAX與TensorFlow和PyTorch的區別  14

1.4  本章小結  16

第2章  你的第一個JAX程序  17

2.1  一個簡單的機器學習問題:手寫數字分類  18

2.2  JAX深度學習項目概覽  19

2.3  加載和準備數據集  20

2.4  在JAX中構建一個簡單的神經網絡  22

2.4.1  神經網絡初始化  24

2.4.2  神經網絡前向傳播  25

2.5  vmap:自動向量化計算以支持批處理  28

2.6  自動微分:如何在不手動計算導數的情況下計算梯度  30

2.6.1  損失函數  32

2.6.2  獲取梯度  33

2.6.3  梯度更新步驟  33

2.6.4  訓練循環  34

2.7  JIT:將代碼編譯為更快的版本  36

2.8  保存和部署模型  37

2.9  純函數和可組合的轉換:它們為什麼重要  39

2.10  練習  40

2.11  本章小結  40

第Ⅱ部分  JAX核心機制

第3章  數組操作  43

3.1  使用NumPy數組進行圖像處理  44

3.1.1  將圖像加載到NumPy數組中  45

3.1.2  對圖像執行基本預處理操作  48

3.1.3  向圖像添加噪聲  50

3.1.4  實現圖像濾波  51

3.1.5  將張量保存為圖像文件  55

3.2  JAX中的數組  57

3.2.1  切換到JAX類似NumPy的API  57

3.2.2  什麼是Array  58

3.2.3  與設備相關的操作  61

3.2.4  異步調度  65

3.2.5  在TPU上運行計算  66

3.3  與NumPy的區別  69

3.3.1  不可變性  70

3.3.2  類型  73

3.4  高級接口與低級接口:jax.numpy和jax.lax  76

3.4.1  控制流原語  77

3.4.2  類型提升  78

3.5  練習  79

3.6  本章小結  79

第4章  計算梯度  81

4.1  獲取導數的不同方法  82

4.1.1  手動求導  83

4.1.2  符號微分  84

4.1.3  數值微分  85

4.1.4  自動微分  87

4.2  使用自動微分計算梯度  89

4.2.1  在TensorFlow中使用梯度  91

4.2.2  在PyTorch中使用梯度  92

4.2.3  在JAX中使用梯度  92

4.2.4  高階導數  100

4.2.5  多變量情況  102

4.3  前向模式和反向模式自動微分  104

4.3.1  計算軌跡  105

4.3.2  前向模式和jvp()  106

4.3.3  反向模式和vjp()  110

4.3.4  深入探索  113

4.4  本章小結  114

第5章  編譯代碼  115

5.1  使用編譯  116

5.1.1  使用JIT編譯  117

5.1.2  純函數與編譯過程  123

5.2  JIT內部機制  125

5.2.1  Jaxpr:JAX程序的中間表示形式  125

5.2.2  XLA  134

5.2.3  使用AOT編譯  138

5.3  JIT的局限性  142

5.3.1  純函數與非純函數  142

5.3.2  精確數值  142

5.3.3  輸入參數值依賴的條件控制  142

5.3.4  編譯速度慢  142

5.3.5  類方法  144

5.3.6  簡單函數  145

5.4  練習  146

5.5  本章小結  146

第6章  向量化代碼  147

6.1  向量化函數的不同方法  148

6.1.1  樸素方法  149

6.1.2  手動向量化  151

6.1.3  自動向量化  151

6.1.4  性能比較  152

6.2  控制vmap()行為  154

6.2.1  控制映射的數組軸  154

6.2.2  控制輸出數組的軸  157

6.2.3  使用命名參數  158

6.2.4  使用裝飾器風格  160

6.2.5  使用集合操作  161

6.3  vmap()的實際應用案例  162

6.3.1  批數據處理  162

6.3.2  批量化神經網絡模型  164

6.3.3  每個樣本的梯度  165

6.3.4  向量化循環  166

6.4  本章小結  169

第7章  並行化計算  171

7.1  使用pmap()並行化計算  172

7.1.1  問題設置  172

7.1.2  像使用vmap一樣使用pmap  175

7.2  控制pmap()的行為  180

7.2.1  控制輸入和輸出的映射軸  181

7.2.2  使用命名軸和集合操作  186

7.3  數據並行的神經網絡訓練示例  193

7.3.1  準備數據和神經網絡結構  194

7.3.2  實現數據並行訓練  196

7.4  使用多主機配置  201

7.5  本章小結  206

第8章  使用張量切分  209

8.1  張量分片基礎  210

8.1.1  設備網格  212

8.1.2  位置分片  213

8.1.3  二維網格示例  213

8.1.4  使用復制  217

8.1.5  分片約束  219

8.1.6  命名切分  221

8.1.7  設備放置策略與錯誤  222

8.2  使用張量分片的多層感知機(MLP)  224

8.2.1  八路數據並行  224

8.2.2  四路數據並行,雙路張量並行  226

8.3  本章小結  229

第9章  JAX中的隨機數  231

9.1  生成隨機數據  232

9.1.1  載入數據集  233

9.1.2  生成隨機噪聲  235

9.1.3  執行隨機增強  239

9.2  與NumPy的區別  241

9.2.1  NumPy的工作原理  241

9.2.2  NumPy中的種子和狀態  243

9.2.3  JAX PRNG  246

9.2.4  JAX PRNG高級配置  252

9.3  在實際應用中生成隨機數  253

9.3.1  構建一個完整的數據增強管道  253

9.3.2  為神經網絡生成隨機初始化  255

9.4  本章小結  256

第10章  處理pytree  257

10.1  將復雜數據結構表示為pytree  258

10.2  處理pytree的函數  262

10.2.1  使用tree_map()  263

10.2.2  扁平化/還原pytree  265

10.2.3  使用tree_reduce()  267

10.2.4  轉置pytree  268

10.3  創建自定義pytree節點  271

10.4  本章小結  274

第Ⅲ部分   生態系統

第11章  高級神經網絡庫  277

11.1  使用MLP進行MNIST圖像分類  278

11.1.1  Flax中的MLP  278

11.1.2  Optax梯度變換庫  284

11.1.3  使用Flax訓練神經網絡  286

11.2  使用ResNet進行圖像分類  290

11.2.1  在Flax中管理狀態  290

11.2.2  使用Orbax保存和加載模型  296

11.3  使用Hugging Face生態系統  298

11.3.1  使用Hugging Face Model Hub中的預訓練模型  299

11.3.2  進一步探索:微調與再訓練  304

11.3.3  使用diffusers庫  306

11.4  本章小結  310

第12章  JAX生態系統的其他成員  313

12.1  深度學習生態系統  314

12.1.1  高層神經網絡庫  314

12.1.2  JAX中的大型語言模型(LLM)  315

12.1.3  工具庫  317

12.2  機器學習模塊  319

12.2.1  強化學習  319

12.2.2  其他機器學習庫  320

12.3  其他領域的JAX模塊  321

12.4  本章小結  322

附錄A  安裝JAX  325

附錄B  使用Google Colab  329

附錄C  使用Google Cloud TPU  331

附錄D  實驗性並行化  335