torch读取数据

library(torch)
library(palmerpenguins)
library(dplyr)

# 以摘要形式展示数据集的结构和内容
penguins %>% glimpse()
展开/折叠结果
Rows: 344
Columns: 8
$ species            Adelie, Adelie, Adelie, Adelie, Adelie…
$ island             Torgersen, Torgersen, Torgersen, Torge…
$ bill_length_mm     39.1, 39.5, 40.3, NA, 36.7, 39.3, 38.9…
$ bill_depth_mm      18.7, 17.4, 18.0, NA, 19.3, 20.6, 17.8…
$ flipper_length_mm  181, 186, 195, NA, 193, 190, 181, 195,…
$ body_mass_g        3750, 3800, 3250, NA, 3450, 3650, 3625…
$ sex                male, female, female, NA, female, male…
$ year               2007, 2007, 2007, 2007, 2007, 2007, 20…
# 数据框转换为张量
penguins_dataset <- dataset(
  name = "penguins_dataset()",
  # 初始化
  initialize = function(df) {
    # 去除NA
    df <- na.omit(df)
    # 取数据框第3到第6列,转换为矩阵,然后进一步转换为PyTorch张量
    self$x <- as.matrix(df[, 3:6]) %>% torch_tensor()
    # 将species列(物种信息)转换为数值型(类别编码),再转为PyTorch长整型张量
    self$y <- torch_tensor(
      as.numeric(df$species)
    )$to(torch_long())
  },
  # 根据索引 i 返回一个单独的数据样本
  .getitem = function(i) {
    list(x = self$x[i, ], y = self$y[i])
  },
  # 返回特征数据 self$x 的行数,即数据集的样本总数
  .length = function() {
    dim(self$x)[1]
  }
)

# 张量赋值给 ds 函数
ds <- penguins_dataset(penguins)
# 查看张量的长度
length(ds)
# 查看第 1 行数据
ds[1]
展开/折叠结果
[1] 333
$x

torch_tensor
   39.1000
   18.7000
  181.0000
 3750.0000
[ CPUFloatType{4} ]

$y
torch_tensor
1
[ CPULongType{} ]

# 创建随机由三个随机张量10个组成的数据集
three <- tensor_dataset(
  # 从正态分布中生成随机数,张量的维数是10
  # 即包含10个元素的1维向量
  torch_randn(10), torch_randn(10), torch_randn(10)
)
# 从数据集中获取第一个样本
three[1]
展开/折叠结果
[[1]]
torch_tensor
 1.3905
[ CPUFloatType{1} ]

[[2]]
torch_tensor
 2.1180
[ CPUFloatType{1} ]

[[3]]
torch_tensor
-0.3222
[ CPUFloatType{1} ]
# 去除带有 NA 的行
penguins <- na.omit(penguins)
# 构建函数,3:6 为矩阵张量;
ds <- tensor_dataset(
  torch_tensor(as.matrix(penguins[, 3:6])),
  torch_tensor(
    # species转化为数值型再转化为长整数型再转化为张量
    as.numeric(penguins$species)
  )$to(torch_long())
)
# 取数据集中第一条数据,两个部分:
# (1)特征张量(相当于x)
# (2)标签(相当于y)
ds[1]
展开/折叠结果
[[1]]
torch_tensor
   39.1000    18.7000   181.0000  3750.0000
[ CPUFloatType{1,4} ]

[[2]]
torch_tensor
 1
[ CPULongType{1} ]

评论

发表评论

了解 数据控|突破是我们的每一步 的更多信息

立即订阅以继续阅读并访问完整档案。

继续阅读