博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
机器学习笔记(6):多类逻辑回归-使用gluon
阅读量:7124 次
发布时间:2019-06-28

本文共 2905 字,大约阅读时间需要 9 分钟。

演示了纯手动添加隐藏层,这次使用gluon让代码更精减,代码来自:

from mxnet import gluonfrom mxnet import ndarray as ndimport matplotlib.pyplot as pltimport mxnet as mxfrom mxnet import autograd  def transform(data, label):    return data.astype('float32')/255, label.astype('float32')  mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)  def show_images(images):    n = images.shape[0]    _, figs = plt.subplots(1, n, figsize=(15, 15))    for i in range(n):        figs[i].imshow(images[i].reshape((28, 28)).asnumpy())        figs[i].axes.get_xaxis().set_visible(False)        figs[i].axes.get_yaxis().set_visible(False)    plt.show()def get_text_labels(label):    text_labels = [        'T 恤', '长 裤', '套头衫', '裙 子', '外 套',        '凉 鞋', '衬 衣', '运动鞋', '包 包', '短 靴'    ]    return [text_labels[int(i)] for i in label]  data, label = mnist_train[0:10]  print('example shape: ', data.shape, 'label:', label)show_images(data)print(get_text_labels(label))  batch_size = 256train_data = gluon.data.DataLoader(mnist_train, batch_size, shuffle=True)test_data = gluon.data.DataLoader(mnist_test, batch_size, shuffle=False)  #计算模型net = gluon.nn.Sequential()with net.name_scope():    net.add(gluon.nn.Flatten())    net.add(gluon.nn.Dense(256, activation="relu"))    net.add(gluon.nn.Dense(10))net.initialize()  softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()#定义训练器trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.5}) def accuracy(output, label):    return nd.mean(output.argmax(axis=1) == label).asscalar()  def _get_batch(batch):    if isinstance(batch, mx.io.DataBatch):        data = batch.data[0]        label = batch.label[0]    else:        data, label = batch    return data, label  def evaluate_accuracy(data_iterator, net):    acc = 0.    if isinstance(data_iterator, mx.io.MXDataIter):        data_iterator.reset()    for i, batch in enumerate(data_iterator):        data, label = _get_batch(batch)        output = net(data)        acc += accuracy(output, label)    return acc / (i+1)  for epoch in range(5):    train_loss = 0.    train_acc = 0.    for data, label in train_data:        with autograd.record():            output = net(data)            loss = softmax_cross_entropy(output, label)        loss.backward()        trainer.step(batch_size) #使用训练器,向"前"走一步        train_loss += nd.mean(loss).asscalar()        train_acc += accuracy(output, label)    test_acc = evaluate_accuracy(test_data, net)    print("Epoch %d. Loss: %f, Train acc %f, Test acc %f" % (        epoch, train_loss/len(train_data), train_acc/len(train_data), test_acc))data, label = mnist_test[0:10]show_images(data)print('true labels')print(get_text_labels(label))  predicted_labels = net(data).argmax(axis=1)print('predicted labels')print(get_text_labels(predicted_labels.asnumpy()))

 有变化的地方,已经加上了注释。运行效果,跟一篇完全相同,就不重复贴图了

转载地址:http://pfael.baihongyu.com/

你可能感兴趣的文章
【redis】redis五大类 用法 【转载:https://www.cnblogs.com/yanan7890/p/6617305.html】
查看>>
【IntelliJ IDEA】idea设置UTF-8的位置
查看>>
人生就是一场修炼
查看>>
PHP 最佳实践(译)——PHP 容易混淆技术的实用指南
查看>>
软工实践-结对作业2
查看>>
OI中的一些模板
查看>>
10.线程池_线程调度
查看>>
C#_delegate - combine function
查看>>
收藏几个HTML5游戏引擎
查看>>
c#利用循环将类实例化为对象
查看>>
win7定时任务
查看>>
input reset 重置时间
查看>>
supervisord
查看>>
Java IO2
查看>>
对抽象函数abstract的运用
查看>>
C\C++编程中:相对路径+绝对路径
查看>>
Farewell to emacs
查看>>
leetcode之Palindrome Partitioning
查看>>
分布式进程
查看>>
第二部分 python基础 day09 python安装与初识
查看>>