Pythonによる勾配降下法の実装
メニューを表示するにはスワイプしてください
勾配降下法は、最急降下の方向に進むことで関数を最小化するという、シンプルながら強力なアイデアに基づいています。
数学的なルールは次の通りです:
theta = theta - alpha * gradient(theta)
ここで:
theta:最適化するパラメータalpha:学習率(ステップサイズ)gradient(theta):thetaにおける関数の勾配
1. 関数とその導関数の定義
まずは単純な二次関数から始めます:
def f(theta):
return theta**2 # Function we want to minimize
その導関数(勾配)は次の通りです:
def gradient(theta):
return 2 * theta # Derivative: f'(theta) = 2*theta
f(theta):この関数の最小値となるthetaを求めます。gradient(theta):任意のthetaにおける傾きを示し、更新方向の決定に利用します。
2. 勾配降下法のパラメータ初期化
alpha = 0.3 # Learning rate
theta = 3.0 # Initial starting point
tolerance = 1e-5 # Convergence threshold
max_iterations = 20 # Maximum number of updates
alpha(学習率):各ステップの大きさを制御theta(初期値):降下の開始点tolerance:更新が非常に小さくなったときに停止max_iterations:無限ループを防ぐための最大回数
3. 勾配降下法の実行
for i in range(max_iterations):
grad = gradient(theta) # Compute gradient
new_theta = theta - alpha * grad # Update rule
if abs(new_theta - theta) < tolerance:
print("Converged!")
break
theta = new_theta
thetaにおける勾配の計算;- 勾配降下法の式を用いた
thetaの更新; - 更新量が十分小さくなったら停止(収束);
- 各ステップを出力して進捗を確認。
4. 勾配降下法の可視化
123456789101112131415161718192021222324252627282930313233343536373839import matplotlib.pyplot as plt import numpy as np def f(theta): return theta**2 # Function we want to minimize def gradient(theta): return 2 * theta # Derivative: f'(theta) = 2*theta alpha = 0.3 # Learning rate theta = 3.0 # Initial starting point tolerance = 1e-5 # Convergence threshold max_iterations = 20 # Maximum number of updates theta_values = [theta] # Track parameter values output_values = [f(theta)] # Track function values for i in range(max_iterations): grad = gradient(theta) # Compute gradient new_theta = theta - alpha * grad # Update rule if abs(new_theta - theta) < tolerance: break theta = new_theta theta_values.append(theta) output_values.append(f(theta)) # Prepare data for plotting the full function curve theta_range = np.linspace(-4, 4, 100) output_range = f(theta_range) # Plot plt.plot(theta_range, output_range, label="f(θ) = θ²", color='black') plt.scatter(theta_values, output_values, color='red', label="Gradient Descent Steps") plt.title("Gradient Descent Visualization") plt.xlabel("θ") plt.ylabel("f(θ)") plt.legend() plt.grid(True) plt.show()
このプロットは以下を示しています:
- 関数曲線 f(θ)=θ2
- 収束までの各勾配降下ステップを表す赤い点
すべて明確でしたか?
フィードバックありがとうございます!
セクション 3. 章 10
AIに質問する
AIに質問する
何でも質問するか、提案された質問の1つを試してチャットを始めてください
セクション 3. 章 10