Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
^.*\.Rproj$
^\.Rproj\.user$
^\.claude$
^README\.Rmd$
^\.github$
^Makefile$
1 change: 1 addition & 0 deletions .github/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.html
51 changes: 51 additions & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
branches: [main, master]
pull_request:

name: R-CMD-check.yaml

permissions: read-all

jobs:
R-CMD-check:
runs-on: ${{ matrix.config.os }}

name: ${{ matrix.config.os }} (${{ matrix.config.r }})

strategy:
fail-fast: false
matrix:
config:
- {os: macos-latest, r: 'release'}
- {os: windows-latest, r: 'release'}
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
- {os: ubuntu-latest, r: 'release'}
- {os: ubuntu-latest, r: 'oldrel-1'}

env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
R_KEEP_PKG_SOURCE: yes

steps:
- uses: actions/checkout@v4

- uses: r-lib/actions/setup-pandoc@v2

- uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}
http-user-agent: ${{ matrix.config.http-user-agent }}
use-public-rspm: true

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v2
with:
upload-snapshots: true
build_args: 'c("--no-manual","--compact-vignettes=gs+qpdf")'
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@
.RData
.Ruserdata
*.Rproj

/.quarto/
scratch.R
README.html
29 changes: 16 additions & 13 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
Package: survnet
Type: Package
Title: Artificial neural networks for survival analysis
Version: 0.0.5
Date: 2018-11-12
Author: Marvin N. Wright
Maintainer: Marvin N. Wright <cran@wrig.de>
Description: Artificial neural networks for survival analysis
Package: survnet
Title: Artificial Neural Networks for Survival Analysis
Version: 0.0.7
Authors@R:
person(c("Marvin", "N."), "Wright", , "cran@wrig.de", role = c("aut", "cre"))
Description: Fits artificial neural networks for survival analysis,
including competing risks, using 'keras3'.
License: MIT + file LICENSE
Depends:
survival,
keras
keras3,
survival
Imports:
magrittr
Suggests:
testthat
RoxygenNote: 6.1.0
magrittr,
reticulate
Suggests:
testthat (>= 3.0.0)
Encoding: UTF-8
RoxygenNote: 7.3.3
Config/testthat/edition: 3
67 changes: 67 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
PKGNAME = `sed -n "s/Package: *\([^ ]*\)/\1/p" DESCRIPTION`
PKGVERS = `sed -n "s/Version: *\([^ ]*\)/\1/p" DESCRIPTION`

all: format doc README.md install check

.PHONY: format
format:
@echo "Formatting with $(shell air --version)"
@air format .

.PHONY: doc
doc: README.md
Rscript -e "usethis::use_tidy_description()"
Rscript -e "devtools::document()"

.PHONY: build
build:
Rscript -e "devtools::build()"

.PHONY: vignettes
vignettes:
Rscript -e "devtools::build_vignettes()"

.PHONY: install
install: doc
Rscript -e "pak::local_install()"

.PHONY: deps
deps:
Rscript -e "pak::local_install_dev_deps()"

# devtools::check() automatically redocuments, so no need to add doc here
.PHONY: check
check:
Rscript -e "devtools::check()"

.PHONY: check-remote
check-remote:
Rscript -e "devtools::check(remote = TRUE)"

.PHONY: test-summary
test-summary:
Rscript -e "devtools::test(reporter = 'summary')"

.PHONY: test-slow
test-slow:
Rscript -e "devtools::test(reporter = 'slow')"

.PHONY: test
test:
Rscript -e "devtools::test()"

.PHONY: coverage
coverage:
Rscript -e "covr::report(covr::package_coverage(\".\"), file = \"coverage.html\")"

.PHONY: site
site:
Rscript -e "pkgdown::build_site()"

README.md: README.Rmd
Rscript -e "rmarkdown::render('README.Rmd')"
rm README.html

clean:
rm -r docs
rm coverage.html
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ S3method(predict,survnet)
export(convert_surv_cens)
export(loss_cif_loglik)
export(survnet)
import(keras)
import(keras3)
import(survival)
importFrom(magrittr,freduce)
importFrom(stats,predict)
11 changes: 11 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# survnet 0.0.7

* Switched from deprecated `keras` package to `keras3`.
* Use `py_require("tensorflow")` on package load to ensure TensorFlow is available.
* Updated loss function to use `keras3` ops (`op_sum`, `op_log`, `op_clip`) instead of legacy `backend()` API.
* Updated `optimizer_rmsprop()` to use `learning_rate` argument (replacing deprecated `lr`).
* Updated `regularizer_l2()` to use `l2` argument (replacing deprecated `l`).
* Removed `CUDNN_LSTM` and `CUDNN_GRU` RNN types (regular LSTM/GRU auto-use CuDNN in modern TensorFlow).
* Wrapped examples in `\donttest{}` and added test skips for environments without TensorFlow.
* Fixed DESCRIPTION metadata (Authors@R, title case, encoding, complete description).
* Used `\doi{}` instead of `\url{}` for DOI references in documentation.
18 changes: 9 additions & 9 deletions R/loss_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@
#' @export
#' @references
#' \itemize{
#' \item Jeong, J. & Fine, J. (2006). Direct parametric inference for the cumulative incidence function. J R Stat Soc Ser C Appl Stat 55:187-200. \url{https://doi.org/10.1111/j.1467-9876.2006.00532.x}.
#' \item Jeong, J. & Fine, J. (2006). Direct parametric inference for the cumulative incidence function. J R Stat Soc Ser C Appl Stat 55:187-200. \doi{10.1111/j.1467-9876.2006.00532.x}.
#' \item Lee, C., Zame, W.R., Yoon, J. & van der Shaar, M. (2018). DeepHit: A deep learning approach to survival analysis with competing risks. AAAI 2018. \url{http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit}.
#' }
loss_cif_loglik <- function(num_intervals, num_causes = 1){
function(y_true, y_pred) {
K <- backend()
eps <- config_epsilon()

# Survival and event indicators
S <- y_true[, 1:(num_causes * num_intervals)] # Survival
E <- y_true[, (num_causes * num_intervals + 1):(2 * num_causes * num_intervals)] # Events

# Likelihood part for uncensored and censored observations (0 for censored)
uncens <- K$sum(E * K$log(K$clip(y_pred, K$epsilon(), NULL)), axis = -1L)
delta <- 1 - K$sum(E, axis = -1L)
cens <- delta * K$log(K$clip(1 - K$sum(S * y_pred, axis = -1L), K$epsilon(), NULL))
# Return negative log-likelihood
uncens <- op_sum(E * op_log(op_clip(y_pred, eps, 1)), axis = -1L)
delta <- 1 - op_sum(E, axis = -1L)
cens <- delta * op_log(op_clip(1 - op_sum(S * y_pred, axis = -1L), eps, 1))

# Return negative log-likelihood
-(uncens + cens)
}
}
32 changes: 14 additions & 18 deletions R/survnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#' @param validation_split Fraction in [0,1] of the training data to be used as validation data.
#' @param loss Loss function.
#' @param activation Activation function.
#' @param rnn_type Type of RNN layers. Either \code{"LSTM"} (default), \code{"GRU"}, \code{"CUDNN_LSTM"} or \code{"CUDNN_GRU"}.
#' @param rnn_type Type of RNN layers. Either \code{"LSTM"} (default) or \code{"GRU"}.
#' @param skip Add skip connection from input and RNN layers to cause-specific layers.
#' @param dropout Vector of dropout rates after each hidden layer. Use 0 for no dropout (default).
#' @param dropout_rnn Vector of dropout rates after each recurrent layer. Use 0 for no dropout (default).
Expand All @@ -24,22 +24,24 @@
#' @param verbose Verbosity mode (0 = silent, 1 = progress bar, 2 = one line per epoch).
#'
#' @return Fitted model.
#' @examples
#' @examples
#' \donttest{
#' library(survival)
#' library(survnet)
#'
#'
#' # Survival data
#' y <- veteran[, c(3, 4)]
#' x <- veteran[, c(-2, -3, -4)]
#' x <- data.frame(lapply(x, scale))
#' breaks <- c(1, 50, 100, 200, 500, 1000)
#'
#'
#' # Fit simple model
#' fit <- survnet(y = y, x = x, breaks = breaks)
#' plot(fit$history)
#' }
#'
#' @export
#' @import survival keras
#' @import survival keras3
#' @importFrom magrittr freduce
survnet <- function(y,
x,
Expand All @@ -60,7 +62,7 @@ survnet <- function(y,
l2 = rep(0, length(units)),
l2_rnn = rep(0, length(units_rnn)),
l2_causes = rep(0, length(units_causes)),
optimizer = optimizer_rmsprop(lr = 0.001),
optimizer = optimizer_rmsprop(learning_rate = 0.001),
verbose = 2) {

# Force evaluation of dependent arguments
Expand Down Expand Up @@ -176,24 +178,18 @@ survnet <- function(y,
return_sequences <- TRUE
}
if (l2_rnn[i] > 0) {
kernel_regularizer <- regularizer_l2(l = l2_rnn[i])
kernel_regularizer <- regularizer_l2(l2 = l2_rnn[i])
} else {
kernel_regularizer <- NULL
}
if (rnn_type == "LSTM") {
layer_lstm(units = units_rnn[i], activation = activation, return_sequences = return_sequences,
layer_lstm(units = units_rnn[i], activation = activation, return_sequences = return_sequences,
kernel_regularizer = kernel_regularizer, name = paste0("rnn_", i))
} else if (rnn_type == "GRU") {
layer_gru(units = units_rnn[i], activation = activation, return_sequences = return_sequences,
layer_gru(units = units_rnn[i], activation = activation, return_sequences = return_sequences,
kernel_regularizer = kernel_regularizer, name = paste0("rnn_", i))
} else if (rnn_type == "CUDNN_LSTM") {
layer_cudnn_lstm(units = units_rnn[i], return_sequences = return_sequences,
kernel_regularizer = kernel_regularizer, name = paste0("rnn_", i))
} else if (rnn_type == "CUDNN_GRU") {
layer_cudnn_gru(units = units_rnn[i], return_sequences = return_sequences,
kernel_regularizer = kernel_regularizer, name = paste0("rnn_", i))
} else {
stop("Unknown rnn_type.")
stop("Unknown rnn_type. Use 'LSTM' or 'GRU'.")
}

})
Expand All @@ -208,7 +204,7 @@ survnet <- function(y,
# non-RNN layers
dense_layers <- lapply(1:length(units), function(i) {
if (l2[i] > 0) {
kernel_regularizer <- regularizer_l2(l = l2[i])
kernel_regularizer <- regularizer_l2(l2 = l2[i])
} else {
kernel_regularizer <- NULL
}
Expand Down Expand Up @@ -257,7 +253,7 @@ survnet <- function(y,
# Cause-specific layers
layers <- lapply(1:length(units_causes[[i]]), function(j) {
if (l2_causes[[i]][j] > 0) {
kernel_regularizer <- regularizer_l2(l = l2_causes[[i]][j])
kernel_regularizer <- regularizer_l2(l2 = l2_causes[[i]][j])
} else {
kernel_regularizer <- NULL
}
Expand Down
3 changes: 3 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.onLoad <- function(libname, pkgname) {
reticulate::py_require("tensorflow")
}
44 changes: 44 additions & 0 deletions README.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
---
title: "survnet: Artificial neural networks for survival analysis"
output: github_document
---

<!-- README.md is generated from README.Rmd. Please edit that file -->

```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%"
)
```

Marvin N. Wright

<!-- badges: start -->
[![R-CMD-check](https://github.com/bips-hb/survnet/actions/workflows/R-CMD-check.yaml/badge.svghttps://github.com/bips-hb/survnet/actions/workflows/R-CMD-check.yaml/badge.svghttps://github.com/bips-hb/survnet/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/bips-hb/survnet/actions/workflows/R-CMD-check.yaml)
<!-- badges: end -->


## Installation

```r
pak::pak("bips-hb/survnet")
```

## Example

```r
library(survnet)

# Survival data
y <- veteran[, c(3, 4)]
x <- veteran[, c(-2, -3, -4)]
x <- data.frame(lapply(x, scale))
breaks <- c(1, 50, 100, 200, 500, 1000)

# Fit simple model
fit <- survnet(y = y, x = x, breaks = breaks)
plot(fit$history)
```
Loading
Loading