以下是PyTorch变量用法的简单示例,将v1和v2相乘的结果赋值给v3,其中里面的参数requires_grad的属性默认为False,若一个节点requires_grad被设置为True,那么所有依赖它的节点的requires_grad都为True,主要用于梯度的计算。
- #Variable(part of autograd package)
- #Variable (graph nodes) are thin wrappers around tensors and have dependency knowle
- #Variable enable backpropagation of gradients and automatic differentiations
- #Variable are set a 'volatile' flad during infrencing
-
-
- from torch.autograd import Variable
- v1 = Variable(torch.tensor([1.,2.,3.]), requires_grad=False)
- v2 = Variable(torch.tensor([4.,5.,6.]), requires_grad=True)
- v3 = v1*v2
-
-
- v3.data.numpy()
运行结果:
- #Variables remember what created them
- v3.grad_fn
运行结果:
Back Propagation
反向传播算法用于计算相对于输入权重和偏差的损失梯度,以在下一次优化迭代中更新权重并最终减少损失,PyTorch在分层定义对于变量的反向方法以执行反向传播方面非常智能。
以下是一个简单的反向传播计算方法,以sin(x)为例计算差分:
- #Backpropagation with example of sin(x)
- x=Variable(torch.Tensor(np.array([0.,1.,1.5,2.])*np.pi),requires_grad=True)
- y=torch.sin(x)
- x.grad
- y.backward(torch.Tensor([1.,1.,1.,1]))
-
-
- #Check gradient is indeed cox(x)
- if( (x.grad.data.int().numpy()==torch.cos(x).data.int().numpy()).all() ):
- print ("d(sin(x)/dx=cos(x))")
运行结果:
对于pytorch中的变量和梯度计算可参考下面这篇文章:
https://zhuanlan.zhihu.com/p/29904755
SLR: Simple Linear Regression
现在我们了解了基础知识,可以开始运用PyTorch 解决简单的机器学习问题——简单线性回归。我们将通过4个简单步骤完成:
第一步:
在步骤1中,我们创建一个由方程y = wx + b产生的人工数据集,并注入随机误差。请参阅以下示例:
- #Simple Liner Regression
- # Fit a line to the data. Y =w.x+b
- #Deterministic behavior
- np.random.seed(0)
- torch.manual_seed(0)
- #Step 1:Dataset
- w=2;b=3
- x=np.linspace(0,10,100)
- y=w*x+b+np.random.randn(100)*2
- xx=x.reshape(-1,1)
- yy=y.reshape(-1,1)
第二步:
在第2步中,我们使用forward函数定义一个简单的类LinearRegressionModel,使用torch.nn.Linear定义构造函数以对输入数据进行线性转换:
- #Step 2:Model
- class LinearRegressionModel(torch.nn.Module):
-
- def __init__(self,in_dimn,out_dimn):
- super(LinearRegressionModel,self).__init__()
- self.model=torch.nn.Linear(in_dimn,out_dimn)
-
- def forward(self,x):
- y_pred=self.model(x);
- return y_pred;
-
- model=LinearRegressionModel(in_dimn=1, out_dimn=1)
torch.nn.Linear参考网站:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html
第三步:
(编辑:源码网)
【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!
|