用python来实现一个RBF径向基神经网络
【摘要】 使用python来实现一个RBF神经网络,代码如下
使用python来实现一个RBF神经网络,代码如下:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
class RBFNetwork:
def __init__(self, num_centers, spread=1.0):
"""
初始化 RBF 网络
"""
self.num_centers = num_centers
self.spread = spread
self.centers = None
self.weights = None
def _gaussian_rbf(self, x, center):
dist_sq = np.sum((x - center) ** 2, axis=1) # 计算欧氏距离的平方
return np.exp(-dist_sq / (2 * self.spread ** 2)) # 应用高斯函数
def fit(self, X, y): # 训练 RBF 网络
N, D = X.shape
kmeans = KMeans(n_clusters=self.num_centers, random_state=42, n_init=10)
kmeans.fit(X)
self.centers = kmeans.cluster_centers_
G = np.zeros((N, self.num_centers))
for j in range(self.num_centers):
G[:, j] = self._gaussian_rbf(X, self.centers[j])
self.weights = np.linalg.pinv(G).dot(y)
def predict(self, X): # 预测新数据
N = X.shape
G = np.zeros((N, self.num_centers))
for j in range(self.num_centers):
G[:, j] = self._gaussian_rbf(X, self.centers[j])
return G.dot(self.weights)
# --- 函数拟合实例 ---
np.random.seed(42)
X_train = np.linspace(-5, 5, 100).reshape(-1, 1)
y_train = np.sin(X_train).ravel() + 0.5 * np.cos(3 * X_train).ravel() + 0.1 * np.random.randn(100)
# 训练 RBF 网络
num_hidden_neurons = 20 # 隐含层节点数
spread_param = 1.5 # 宽度参数,控制平滑度
rbf_net = RBFNetwork(num_centers=num_hidden_neurons, spread=spread_param)
rbf_net.fit(X_train, y_train)
# 预测
X_test = np.linspace(-5, 5, 200).reshape(-1, 1)
y_pred = rbf_net.predict(X_test)
y_true_test = np.sin(X_test).ravel() + 0.5 * np.cos(3 * X_test).ravel()
# 结果
plt.figure(figsize=(10, 6))
plt.scatter(X_train, y_train, color='blue', label='Training Data (Noisy)', alpha=0.6)
plt.plot(X_test, y_true_test, 'g-', label='True Function', linewidth=2)
plt.plot(X_test, y_pred, 'r--', label='RBF Prediction', linewidth=2)
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True)
plt.show()
print("训练完成。RBF 网络成功逼近了非线性函数。")
python运行上面的代码就可以了。
【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)