JAX是一個用于高性能數值計算的Python庫,專門為深度學習領域的高性能計算而設計。本書詳解JAX框架深度學習的相關知識,配套示例源碼、PPT課件、數據集和開發環境。 本書共分為13章,內容包括JAX從零開始,一學就會的線性回歸、多層感知機與自動微分器,深度學習的理論基礎,XLA與JAX一般特性,JAX的高級特性,JAX的一些細節,JAX中的卷積,JAX與TensorFlow的比較與交互,遵循JAX函數基本規則下的自定義函數,JAX中的高級包。后給出3個實戰案例:使用ResNet完成CIFAR100數據集分類,有趣的詞嵌入,生成對抗網絡(GAN)。 本書適合JAX框架初學者、深度學習初學者以及深度學習從業人員,也適合作為高等院校和培訓機構人工智能相關專業的師生教學參考書。
JAX是一個用于高性能數值計算的Python庫,專門為深度學習領域的高性能計算而設計。本書詳解JAX框架深度學習的相關知識,并剖析3個實戰案例: 使用ResNet完成CIFAR100數據集分類、有趣的詞嵌入與生成對抗網絡。本書配套示例源碼、PPT課件、數據集、開發環境和答疑服務。
深度學習和人工智能引領了一個新的研究和發展方向,同時正在改變人類固有的處理問題的思維。現在各個領域都處于運用深度學習技術進行重大技術突破的階段,與此同時,深度學習本身也展現出巨大的發展空間。
JAX是一個用于高性能數值計算的Python庫,專門為深度學習領域的高性能計算而設計,其包含豐富的數值計算與科學計算函數,能夠很好地滿足用戶的計算需求,特別是其基于GPU或者其他硬件加速器的能力,能夠幫助我們在現有的條件下極大地加速深度學習模型的訓練與預測。
JAX繼承了Python簡單易用的優點,給使用者提供了一個便于入門,能夠提高的深度學習實現方案。JAX在代碼結構上采用面向對象方法編寫,完全模塊化,并具有可擴展性,其運行機制和說明文檔都將用戶體驗和使用難度納入考慮范圍,降低了復雜算法的實現難度。JAX的計算核心使用的是自動微分,可以支持自動模式反向傳播和正向傳播,且二者可以任意組合成任何順序。
本書由淺到深地向讀者介紹JAX框架相關的知識,重要內容均結合代碼進行實戰講解,讀者通過這些實例可以深入掌握JAX程序設計的內容,并能對深度學習有進一步的了解。
本書特色
版本新,易入門
本書詳細介紹JAX最新版本的安裝和使用,包括CPU版本以及GPU版本。
作者經驗豐富,代碼編寫細膩
作者是長期奮戰在科研和工業界的一線算法設計和程序編寫人員,實戰經驗豐富,對代碼中可能會出現的各種問題和坑有豐富的處理經驗,使得讀者能夠少走很多彎路。
理論扎實,深入淺出
在代碼設計的基礎上,本書深入淺出地介紹深度學習需要掌握的一些基本理論知識,并通過大量的公式與圖示對理論做介紹。
對比多種應用方案,實戰案例豐富
本書給出了大量的實例,同時提供多個實現同類功能的解決方案,覆蓋使用JAX進行深度學習開發中常用的知識。
本書內容
第1章 JAX從零開始
本章介紹JAX應用于深度學習的基本理念、基礎,并通過一個真實的深度學習例子向讀者展示深度學習的一般訓練步驟。本章是全書的基礎,讀者需要先根據本章內容搭建JAX開發環境,并下載合適的IDE。
第2章 一學就會的線性回歸、多層感知機與自動微分器
本章以深度學習中最常見的線性回歸和多層感知機的程序設計為基礎,循序漸進地介紹JAX進行深度學習程序設計的基本方法和步驟。
第3章 深度學習的理論基礎
本章主要介紹深度學習的理論基礎,從BP神經網絡開始,介紹神經網絡兩個基礎算法,并著重介紹反向傳播算法的完整過程和理論,最后通過編寫基本Python的方式實現一個完整的反饋神經網絡。
第4章 XLA與JAX一般特性
本章主要介紹JAX的一些基礎特性,例如XLA、自動微分等。讀者需要了解的是XLA是如何工作的,它能給JAX帶來什么。
第5章 JAX的高級特性
本章是基于上一章的基礎比較JAX與NumPy,重點解釋JAX在實踐中的一些程序設計和編寫的規范要求,并對其中的循環函數做一個詳盡而細致的說明。
第6章 JAX的一些細節
本章主要介紹JAX在設計性能較優的程序時的細節問題,并介紹JAX內部一整套結構體保存方法和對模型參數的控制,這些都是為我們能編寫出更為強大的深度學習代碼打下基礎。
第7章 JAX中的卷積
卷積可以說是深度學習中使用最為廣泛的計算部件,本章主要介紹卷積的基礎知識以及相關用法,并通過一個經典的卷積神經網絡VGG模型,講解卷積的應用和JAX程序設計的一些基本內容。
第8章 JAX與TensorFlow的比較與交互
本章主要介紹在一些需要的情況下使用已有的TensorFlow組件的一些方法。作為深度學習經典框架,TensorFlow有很多值得JAX參考和利用的內容。
第9章 遵循JAX函數基本規則下的自定義函數
本章介紹JAX創建自定義函數的基本規則,并對其中涉及的一些細節問題進行詳細講解。期望讀者在了解和掌握如何利用和遵循這些基本規則后去創建既滿足需求又能夠符合JAX函數規則的自定義函數。
第10章 JAX中的高級包
本章詳細介紹JAX中高級程序設計子包,特別是2個非常重要的模塊jax.experimental和jax.nn。這兩個包目前仍處于測試階段,但是包含了建立深度學習模型所必需的一些基本函數。
第11章 JAX實戰使用ResNet完成CIFAR100數據集分類
本章主要介紹在神經網絡領域具有里程碑意義的模型ResNet。它改變了人們僅僅依靠堆積神經網絡層來獲取更高性能的做法,在一定程度上解決了梯度消失和梯度爆炸的問題。這是一項跨時代的發明。本章以手把手的方式向讀者介紹ResNet模型的編寫和架構方法。
第12章 JAX實戰有趣的詞嵌入
本章介紹JAX于自然語言處理的應用,通過一個完整的例子向讀者介紹自然語言處理所需要的全部內容,一步步地教會讀者使用不同的架構和維度進行文本分類的方法。
第13章 JAX實戰生成對抗網絡(GAN)
本章介紹使用JAX完成生成對抗網絡模型的設計,講解如何利用JAX完成更為復雜的深度學習模型設計,掌握JAX程序設計的技巧。同時也期望通過本章能夠幫助讀者全面復習本書所涉及的JAX的深度學習程序設計內容。
源碼下載與技術支持
本書配套源碼、PPT課件、數據集、開發環境、配圖文件和答疑服務,需要使用微信掃描右邊二維碼下載,可按頁面提示,把鏈接轉發到自己的郵箱中下載。如果下載有問題或者閱讀中發現問題,請聯系booksaga@163.com,郵件主題為谷歌JAX深度學習從零開始學。
本書讀者
人工智能入門讀者
深度學習入門讀者
機器學習入門讀者
高等院校人工智能專業的師生
專業培訓機構的師生
其他對智能化、自動化感興趣的開發者
技術支持、勘誤和鳴謝
由于作者的水平有限,加上JAX框架的演進較快,書中難免存在疏漏之處,懇請讀者來信批評指正。本書的順利出版,首先要感謝家人的理解和支持,他們給予我莫大的動力,讓我的努力更加有意義。此外特別感謝出版社的編輯們,感謝他們在本書編寫過程中給予的無私指導。
編 者
2022年4月
王曉華,計算機專業講師,研究方向為云計算、大數據與人工智能。著有《Spark MLlib機器學習實踐》《TensorFlow深度學習應用實踐》《OpenCV TensorFlow深度學習與計算機視覺實戰》《TensorFlow知識圖譜實戰》《TensorFlow人臉識別實戰》《TensorFlow語音識別實戰》《TensorFlow 2.0卷積神經網絡實戰》《Keras實戰:基于TensorFlow2.2的深度學習實踐》《TensorFlow深度學習從零開始學》《深度學習的數學原理與實現》等圖書。
第 1 章 JAX從零開始 1
1.1 JAX來了 1
1.1.1 JAX是什么 1
1.1.2 為什么是JAX 2
1.2 JAX的安裝與使用 3
1.2.1 Windows Subsystem for Linux的安裝 3
1.2.2 JAX的安裝和驗證 7
1.2.3 PyCharm的下載與安裝 8
1.2.4 使用PyCharm和JAX 9
1.2.5 JAX的Python代碼小練習:計算SeLU函數 11
1.3 JAX實戰MNIST手寫體的識別 12
1.3.1 第一步:準備數據集 12
1.3.2 第二步:模型的設計 13
1.3.3 第三步:模型的訓練 13
1.4 本章小結 15
第2章 一學就會的線性回歸、多層感知機與自動微分器 16
2.1 多層感知機 16
2.1.1 全連接層多層感知機的隱藏層 16
2.1.2 使用JAX實現一個全連接層 17
2.1.3 更多功能的全連接函數 19
2.2 JAX實戰鳶尾花分類 22
2.2.1 鳶尾花數據準備與分析 23
2.2.2 模型分析采用線性回歸實戰鳶尾花分類 24
2.2.3 基于JAX的線性回歸模型的編寫 25
2.2.4 多層感知機與神經網絡 27
2.2.5 基于JAX的激活函數、softmax函數與交叉熵函數 29
2.2.6 基于多層感知機的鳶尾花分類實戰 31
2.3 自動微分器 35
2.3.1 什么是微分器 36
2.3.2 JAX中的自動微分 37
2.4 本章小結 38
第3章 深度學習的理論基礎 39
3.1 BP神經網絡簡介 39
3.2 BP神經網絡兩個基礎算法詳解 42
3.2.1 最小二乘法詳解 43
3.2.2 道士下山的故事梯度下降算法 44
3.2.3 最小二乘法的梯度下降算法以及JAX實現 46
3.3 反饋神經網絡反向傳播算法介紹 52
3.3.1 深度學習基礎 52
3.3.2 鏈式求導法則 53
3.3.3 反饋神經網絡原理與公式推導 54
3.3.4 反饋神經網絡原理的激活函數 59
3.3.5 反饋神經網絡原理的Python實現 60
3.4 本章小結 64
第4章 XLA與JAX一般特性 65
4.1 JAX與XLA 65
4.1.1 XLA如何運行 65
4.1.2 XLA如何工作 67
4.2 JAX一般特性 67
4.2.1 利用JIT加快程序運行 67
4.2.2 自動微分器grad函數 68
4.2.3 自動向量化映射vmap函數 70
4.3 本章小結 71
第5章 JAX的高級特性 72
5.1 JAX與NumPy 72
5.1.1 像NumPy一樣運行的JAX 72
5.1.2 JAX的底層實現lax 74
5.1.3 并行化的JIT機制與不適合使用JIT的情景 75
5.1.4 JIT的參數詳解 77
5.2 JAX程序的編寫規范要求 78
5.2.1 JAX函數必須要為純函數 79
5.2.2 JAX中數組的規范操作 80
5.2.3 JIT中的控制分支 83
5.2.4 JAX中的if、while、for、scan函數 85
5.3 本章小結 89
第6章 JAX的一些細節 90
6.1 JAX中的數值計算 90
6.1.1 JAX中的grad函數使用細節 90
6.1.2 不要編寫帶有副作用的代碼JAX與NumPy的差異 93
6.1.3 一個簡單的線性回歸方程擬合 94
6.2 JAX中的性能提高 98
6.2.1 JIT的轉換過程 98
6.2.2 JIT無法對非確定參數追蹤 100
6.2.3 理解JAX中的預編譯與緩存 102
6.3 JAX中的函數自動打包器vmap 102
6.3.1 剝洋蔥對數據的手工打包 102
6.3.2 剝甘藍JAX中的自動向量化函數vmap 104
6.3.3 JAX中高階導數的處理 105
6.4 JAX中的結構體保存方法Pytrees 106
6.4.1 Pytrees是什么 106
6.4.2 常見的pytree函數 107
6.4.3 深度學習模型參數的控制(線性模型) 108
6.4.4 深度學習模型參數的控制(非線性模型) 113
6.4.5 自定義的Pytree節點 113
6.4.6 JAX數值計算的運行機制 115
6.5 本章小結 117
第7章 JAX中的卷積 118
7.1 什么是卷積 118
7.1.1 卷積運算 119
7.1.2 JAX中的一維卷積與多維卷積的計算 120
7.1.3 JAX.lax中的一般卷積的計算與表示 122
7.2 JAX實戰基于VGG架構的MNIST數據集分類 124
7.2.1 深度學習Visual Geometry Group(VGG)架構 124
7.2.2 VGG中使用的組件介紹與實現 126
7.2.3 基于VGG6的MNIST數據集分類實戰 129
7.3 本章小結 133
第8章 JAX與TensorFlow的比較與交互 134
8.1 基于TensorFlow的MNIST分類 134
8.2 TensorFlow與JAX的交互 137
8.2.1 基于JAX的TensorFlow Datasets數據集分類實戰 137
8.2.2 TensorFlow Datasets數據集庫簡介 141
8.3 本章小結 145
第9章 遵循JAX函數基本規則下的自定義函數 146
9.1 JAX函數的基本規則 146
9.1.1 使用已有的原語 146
9.1.2 自定義的JVP以及反向VJP 147
9.1.3 進階jax.custom_jvp和jax.custom_vjp函數用法 150
9.2 Jaxpr解釋器的使用 153
9.2.1 Jaxpr tracer 153
9.2.2 自定義的可以被Jaxpr跟蹤的函數 155
9.3 JAX維度名稱的使用 157
9.3.1 JAX的維度名稱 157
9.3.2 自定義JAX中的向量Tensor 158
9.4 本章小結 159
第10章 JAX中的高級包 160
10.1 JAX中的包 160
10.1.1 jax.numpy的使用 161
10.1.2 jax.nn的使用 162
10.2 jax.experimental包和jax.example_libraries的使用 163
10.2.1 jax.experimental.sparse的使用 163
10.2.2 jax.experimental.optimizers模塊的使用 166
10.2.3 jax.experimental.stax的使用 168
10.3 本章小結 168
第11章 JAX實戰使用ResNet完成CIFAR100數據集分類 169
11.1 ResNet基礎原理與程序設計基礎 169
11.1.1 ResNet誕生的背景 170
11.1.2 使用JAX中實現的部件不要重復造輪子 173
11.1.3 一些stax模塊中特有的類 175
11.2 ResNet實戰CIFAR100數據集分類 176
11.2.1 CIFAR100數據集簡介 176
11.2.2 ResNet殘差模塊的實現 179
11.2.3 ResNet網絡的實現 181
11.2.4 使用ResNet對CIFAR100數據集進行分類 182
11.3 本章小結 184
第12章 JAX實戰有趣的詞嵌入 185
12.1 文本數據處理 185
12.1.1 數據集和數據清洗 185
12.1.2 停用詞的使用 188
12.1.3 詞向量訓練模型word2vec的使用 190
12.1.4 文本主題的提取:基于TF-IDF 193
12.1.5 文本主題的提取:基于TextRank 197
12.2 更多的詞嵌入方法FastText和預訓練詞向量 200
12.2.1 FastText的原理與基礎算法 201
12.2.2 FastText訓練以及與JAX的協同使用 202
12.2.3 使用其他預訓練參數嵌入矩陣(中文) 204
12.3 針對文本的卷積神經網絡模型字符卷積 205
12.3.1 字符(非單詞)文本的處理 206
12.3.2 卷積神經網絡文本分類模型的實現conv1d(一維卷積) 213
12.4 針對文本的卷積神經網絡模型詞卷積 216
12.4.1 單詞的文本處理 216
12.4.2 卷積神經網絡文本分類模型的實現 218
12.5 使用卷積對文本分類的補充內容 219
12.5.1 中文的文本處理 219
12.5.2 其他細節 222
12.6 本章小結 222
第13章 JAX實戰生成對抗網絡(GAN) 223
13.1 GAN的工作原理詳解 223
13.1.1 生成器與判別器共同構成了一個GAN 224
13.1.2 GAN是怎么工作的 225
13.2 GAN的數學原理詳解 225
13.2.1 GAN的損失函數 226
13.2.2 生成器的產生分布的數學原理相對熵簡介 226
13.3 JAX實戰GAN網絡 227
13.3.1 生成對抗網絡GAN的實現 228
13.3.2 GAN的應用前景 232
13.4 本章小結 235
附錄 Windows 11安裝GPU版本的JAX 236