
uqregressors.utils.activations

activations

get_activation(name)

Return a PyTorch activation class (a `torch.nn.Module` subclass such as `nn.ReLU`) from its name, given as a case-insensitive string.

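Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Name of the activation function (case-insensitive). One of `"relu"`, `"leaky_relu"`, `"tanh"`, `"sigmoid"`, `"gelu"`, `"elu"`, `"selu"`, or `"none"` (identity). | required |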
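The sketch below illustrates how `name` is handled: it is lowercased before lookup, `"none"` maps to `nn.Identity`, and anything outside the table raises `ValueError`. It assumes PyTorch is installed. The `try`/`except` fallback is a hypothetical stand-in that mirrors the mapping in the source listing, included only so the snippet runs where `uqregressors` is not installed.

```python
import torch.nn as nn

try:
    from uqregressors.utils.activations import get_activation
except ImportError:
    # Assumption: uqregressors is unavailable, so use a minimal stand-in
    # that mirrors the name-to-class mapping shown in the source listing.
    _ACTIVATIONS = {
        "relu": nn.ReLU, "leaky_relu": nn.LeakyReLU, "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid, "gelu": nn.GELU, "elu": nn.ELU,
        "selu": nn.SELU, "none": nn.Identity,
    }

    def get_activation(name: str):
        name = name.lower()
        if name not in _ACTIVATIONS:
            raise ValueError(f"Unsupported activation: {name}")
        return _ACTIVATIONS[name]

# Lookup is case-insensitive: all three spellings resolve to the same class.
assert get_activation("ReLU") is get_activation("relu") is get_activation("RELU") is nn.ReLU

# "none" maps to nn.Identity, i.e. no nonlinearity is applied.
assert get_activation("None") is nn.Identity

# Names outside the supported table raise ValueError.
try:
    get_activation("swish")
except ValueError as err:
    message = str(err)
print(message)  # Unsupported activation: swish
```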

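Returns:

| Type | Description |
| --- | --- |
| type[torch.nn.Module] | The activation class (not an instance), e.g. `nn.ReLU`. Call it to construct the module: `get_activation("relu")()`. |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If `name` is not one of the supported activations. |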
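Because the function returns a class rather than a module instance, the caller must instantiate it before use, which also lets constructor arguments pass through. The sketch below shows this with `nn.LeakyReLU` and its `negative_slope` argument. It assumes PyTorch is installed; the fallback `get_activation` is a hypothetical stand-in mirroring the source mapping, used only when `uqregressors` is not importable.

```python
import torch
import torch.nn as nn

try:
    from uqregressors.utils.activations import get_activation
except ImportError:
    # Assumption: uqregressors is unavailable; minimal stand-in mirroring the source mapping.
    _ACTIVATIONS = {
        "relu": nn.ReLU, "leaky_relu": nn.LeakyReLU, "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid, "gelu": nn.GELU, "elu": nn.ELU,
        "selu": nn.SELU, "none": nn.Identity,
    }

    def get_activation(name: str):
        name = name.lower()
        if name not in _ACTIVATIONS:
            raise ValueError(f"Unsupported activation: {name}")
        return _ACTIVATIONS[name]

act_cls = get_activation("leaky_relu")
print(act_cls is nn.LeakyReLU)          # True: the class itself is returned
print(isinstance(act_cls, nn.Module))   # False: it is not yet a module instance

# Instantiate it; constructor arguments such as negative_slope pass straight through.
act = act_cls(negative_slope=0.1)
out = act(torch.tensor([-1.0, 2.0]))    # negative inputs are scaled by 0.1
print(out)
```

Returning the class keeps the lookup independent of any per-layer configuration; the cost is that forgetting the trailing `()` passes a class where a module is expected.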

Source code in uqregressors\utils\activations.py
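```python
import torch.nn as nn


def get_activation(name: str):
    """
    Return a PyTorch activation class from its case-insensitive name.

    Args:
        name (str): Name of the activation function, e.g. "relu" or "GELU".

    Returns:
        (type[torch.nn.Module]): The activation class (not an instance); call it
            to construct the module, e.g. get_activation("relu")().

    Raises:
        ValueError: If name is not one of the supported activations.
    """
    name = name.lower()  # lookup is case-insensitive
    # Map each supported name to its torch.nn class (classes, not instances)
    activations = {
        "relu": nn.ReLU,
        "leaky_relu": nn.LeakyReLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "gelu": nn.GELU,
        "elu": nn.ELU,
        "selu": nn.SELU,
        "none": nn.Identity,  # no nonlinearity
    }
    if name not in activations:
        raise ValueError(f"Unsupported activation: {name}")
    return activations[name]
```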
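A typical use is letting a model's hidden activation be chosen by a string in its configuration. The sketch below builds a small feed-forward network this way; `make_mlp` is a hypothetical helper written for illustration, not part of `uqregressors`. It assumes PyTorch is installed, and the fallback `get_activation` is a stand-in mirroring the source mapping for environments without `uqregressors`.

```python
import torch
import torch.nn as nn

try:
    from uqregressors.utils.activations import get_activation
except ImportError:
    # Assumption: uqregressors is unavailable; minimal stand-in mirroring the source mapping.
    _ACTIVATIONS = {
        "relu": nn.ReLU, "leaky_relu": nn.LeakyReLU, "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid, "gelu": nn.GELU, "elu": nn.ELU,
        "selu": nn.SELU, "none": nn.Identity,
    }

    def get_activation(name: str):
        name = name.lower()
        if name not in _ACTIVATIONS:
            raise ValueError(f"Unsupported activation: {name}")
        return _ACTIVATIONS[name]


def make_mlp(in_dim: int, hidden: list, out_dim: int, activation: str = "relu") -> nn.Sequential:
    """Hypothetical helper: feed-forward net whose hidden activation is chosen by name."""
    act_cls = get_activation(activation)  # resolve once; raises ValueError on bad names
    layers = []
    prev = in_dim
    for width in hidden:
        # A fresh activation instance per hidden layer
        layers += [nn.Linear(prev, width), act_cls()]
        prev = width
    layers.append(nn.Linear(prev, out_dim))  # linear output layer, no activation
    return nn.Sequential(*layers)


model = make_mlp(3, [16, 16], 1, activation="gelu")
y = model(torch.randn(8, 3))
print(y.shape)  # torch.Size([8, 1])
```

Resolving the name once at construction time means an invalid activation string fails immediately, before any training starts.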