Logistic回归实现芯片测试结果分类
【摘要】 基于芯片测试结果数据,用PyTorch实现Logistic回归模型,对芯片是否被接受(accepted)进行分类,并绘制决策边界。
import functools
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn
from torch import *
from torch.autograd import Variable
import seaborn as sns
def transfrom(degree, x):
    """Expand *x* into every polynomial term of total order 0..*degree*.

    The name is a misspelling of "transform"; it is kept unchanged so that
    existing callers keep working.

    Args:
        degree: highest total polynomial order to generate. ``0`` yields only
            the constant column.
        x: array-like of shape (n_samples,) or (n_samples, n_features).
            A 1-D input is treated as a single feature column.

    Returns:
        ndarray of shape (n_samples, n_terms). Column 0 is the constant 1
        (bias term). It is followed by every monomial of order 1, 2, ...,
        *degree* in ``itertools.combinations_with_replacement`` order. For
        two input features and degree 2 the columns are
        1, x1, x2, x1**2, x1*x2, x2**2.
    """
    x = np.asarray(x)
    if x.ndim == 1:
        x = x[:, None]
    # Transpose so that iterating yields one feature column at a time.
    columns = x.T
    features = [np.ones(len(x))]
    # A distinct loop variable keeps ``degree`` meaning the requested maximum
    # order for the whole function instead of being silently rebound.
    for order in range(1, degree + 1):
        for combo in itertools.combinations_with_replacement(columns, order):
            # Element-wise product of the chosen columns = one monomial term.
            features.append(functools.reduce(lambda a, b: a * b, combo))
    return np.asarray(features).T
class LR(torch.nn.Module):
    """Logistic-regression model: a single affine layer followed by a sigmoid.

    Args:
        in_features: number of input features. Defaults to 6, which is the
            width of ``transfrom(2, ...)`` applied to two raw features
            (1, x1, x2, x1**2, x1*x2, x2**2).
    """

    def __init__(self, in_features=6):
        super().__init__()
        # Attribute name ``linear`` is relied on via state_dict keys
        # 'linear.weight' / 'linear.bias'.
        self.linear = torch.nn.Linear(in_features, 1)

    def forward(self, x):
        """Map input of shape (..., in_features) to probabilities of shape (..., 1)."""
        # torch.sigmoid is the functional form of nn.Sigmoid; it avoids
        # constructing a new module on every forward pass.
        return torch.sigmoid(self.linear(x))
# Load the data set: two feature columns plus an 'accepted' label column
# (presumably 0/1, since it is used as a BCELoss target below — confirm).
data = pd.read_csv('./ex2data2.txt', names=['feature1', 'feature2', 'accepted'])
# Expand the two raw features into the degree-2 polynomial features LR consumes.
x_data, T = transfrom(2, data.iloc[:, :-1].values), data.iloc[:, -1].values
# torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4: plain
# tensors carry autograd state themselves. A single float32 cast followed by
# from_numpy replaces the original float64 -> tensor -> float32 round-trip.
X = torch.from_numpy(x_data.astype(np.float32))
t = torch.from_numpy(T.astype(np.float32))
model = LR()
criterion = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
# Per-epoch bookkeeping: epoch numbers and loss values
# (not used further in the visible script).
cnt, Loss = list(), list()
# The mode never changes inside the loop, so set it once up front.
model.train()
for epoch in range(10000):
    optimizer.zero_grad()
    # LR returns shape (N, 1); squeeze to (N,) so it matches the target vector t.
    y = model(X).squeeze(-1)
    loss = criterion(y, t)
    Loss.append(loss.item())
    cnt.append(epoch + 1)
    loss.backward()
    optimizer.step()
# Pull the learned parameters out as NumPy values: th is the weight vector
# (one entry per polynomial feature), th0 the bias scalar.
th, th0 = (model.state_dict()[key].numpy()[0] for key in ('linear.weight', 'linear.bias'))

# Evaluate the linear score th0 + th . phi(p) on a 50x50 grid spanning
# [-1, 1.5] in both features. The sigmoid output equals 0.5 where z == 0.
x = np.linspace(-1, 1.5, 50)
xx, yy = np.meshgrid(x, x)
z = transfrom(2, np.column_stack((xx.ravel(), yy.ravel())))
z = (z @ th + th0).reshape(xx.shape)
def _scatter_chip_data():
    """Draw a feature1-vs-feature2 scatter of ``data``, coloured by 'accepted'."""
    sns.set(context='notebook', style='white', font_scale=1.5)
    # x/y must be passed by keyword: seaborn >= 0.12 made them keyword-only,
    # and the old positional form raises a TypeError there.
    sns.lmplot(x='feature1', y='feature2', hue='accepted', data=data, height=6,
               fit_reg=False, scatter_kws={'s': 50}, palette="Dark2_r")


_scatter_chip_data()
plt.title("Raw data from chip testing")
plt.show()

_scatter_chip_data()
# sigmoid(z) = 0.5 exactly where the linear score z is 0, so the decision
# boundary is the single contour level 0. A bare int 0 would instead be read
# as a *number* of automatically chosen levels, not the level value.
plt.contour(xx, yy, z, levels=[0])
plt.title("Logistic regression for chip testing")
plt.show()
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)