Source code for gamspy.formulations.nn.torch_sequential

from __future__ import annotations

import typing
from functools import partial
from typing import TYPE_CHECKING

import gamspy as gp
from gamspy.exceptions import ValidationError

if TYPE_CHECKING:
    import torch


def convert_linear(m: gp.Container, layer: torch.nn.Linear) -> gp.formulations.Linear:
    """
    Convert a PyTorch ``Linear`` layer into a GAMSPy ``Linear`` formulation.

    The formulation is created with the layer's ``in_features`` and
    ``out_features`` and is loaded with the layer's weight. The bias is
    loaded only when the layer has one.

    Parameters
    ----------
    m : Container
        Container the formulation is created in.
    layer : torch.nn.Linear
        Layer to convert.

    Returns
    -------
    gp.formulations.Linear
        Formulation with the weights already loaded.
    """
    has_bias = layer.bias is not None
    linear = gp.formulations.Linear(
        m,
        in_features=layer.in_features,
        out_features=layer.out_features,
        bias=has_bias,
    )
    # Tensor.numpy() refuses tensors that live on a non-CPU device or that
    # still track gradients. detach().cpu() makes the conversion work for
    # any model; both are no-ops for plain CPU tensors.
    linear.load_weights(
        layer.weight.detach().cpu().numpy(),
        layer.bias.detach().cpu().numpy() if has_bias else None,
    )
    return linear


def convert_conv1d(m: gp.Container, layer: torch.nn.Conv1d) -> gp.formulations.Conv1d:
    """
    Convert a PyTorch ``Conv1d`` layer into a GAMSPy ``Conv1d`` formulation.

    Channel counts, ``kernel_size``, ``stride`` and ``padding`` are copied
    from the layer. The layer's weight is loaded, and so is its bias when
    present.

    Parameters
    ----------
    m : Container
        Container the formulation is created in.
    layer : torch.nn.Conv1d
        Layer to convert.

    Returns
    -------
    gp.formulations.Conv1d
        Formulation with the weights already loaded.

    Raises
    ------
    ValidationError
        If the layer has a dilation other than 1, ``groups`` other than 1,
        or a ``padding_mode`` other than ``"zeros"``. These options are not
        passed to the GAMSPy formulation.
    """
    if layer.dilation[0] != 1:
        raise ValidationError("Conv1d is not supported when dilation is not 1")

    if layer.groups != 1:
        raise ValidationError("Conv1d is not supported when groups is not 1")

    if layer.padding_mode != "zeros":
        raise ValidationError("Conv1d is only supported with padding_mode zeros")

    has_bias = layer.bias is not None
    conv = gp.formulations.Conv1d(
        m,
        in_channels=layer.in_channels,
        out_channels=layer.out_channels,
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        bias=has_bias,
    )

    # detach().cpu() lets models that live on a GPU (or tensors that track
    # gradients) be converted; Tensor.numpy() rejects both directly.
    conv.load_weights(
        layer.weight.detach().cpu().numpy(),
        layer.bias.detach().cpu().numpy() if has_bias else None,
    )
    return conv


def convert_conv2d(m: gp.Container, layer: torch.nn.Conv2d) -> gp.formulations.Conv2d:
    """
    Convert a PyTorch ``Conv2d`` layer into a GAMSPy ``Conv2d`` formulation.

    Channel counts, ``kernel_size``, ``stride`` and ``padding`` are copied
    from the layer. The layer's weight is loaded, and so is its bias when
    present.

    Parameters
    ----------
    m : Container
        Container the formulation is created in.
    layer : torch.nn.Conv2d
        Layer to convert.

    Returns
    -------
    gp.formulations.Conv2d
        Formulation with the weights already loaded.

    Raises
    ------
    ValidationError
        If any dilation entry differs from 1, ``groups`` is not 1, or
        ``padding_mode`` is not ``"zeros"``. These options are not passed to
        the GAMSPy formulation.
    """
    if any(d != 1 for d in layer.dilation):
        raise ValidationError("Conv2d is not supported when dilation is not 1")

    if layer.groups != 1:
        raise ValidationError("Conv2d is not supported when groups is not 1")

    if layer.padding_mode != "zeros":
        raise ValidationError("Conv2d is only supported with padding_mode zeros")

    has_bias = layer.bias is not None
    conv = gp.formulations.Conv2d(
        m,
        in_channels=layer.in_channels,
        out_channels=layer.out_channels,
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        bias=has_bias,
    )

    # detach().cpu() lets models that live on a GPU (or tensors that track
    # gradients) be converted; Tensor.numpy() rejects both directly.
    conv.load_weights(
        layer.weight.detach().cpu().numpy(),
        layer.bias.detach().cpu().numpy() if has_bias else None,
    )
    return conv


def convert_relu(m: gp.Container, layer: torch.nn.ReLU):
    """
    Map a PyTorch ``ReLU`` layer to GAMSPy's binary-variable ReLU function.

    ``ReLU`` has nothing to copy over, so ``m`` and ``layer`` are ignored.
    The returned function itself serves as the converted layer.
    """
    relu_formulation = gp.math.relu_with_binary_var
    return relu_formulation


def convert_leaky_relu(m: gp.Container, layer: torch.nn.LeakyReLU):
    """
    Map a PyTorch ``LeakyReLU`` layer to GAMSPy's binary-variable leaky ReLU.

    The layer's ``negative_slope`` is bound with ``functools.partial``, so
    the returned callable only needs the input. ``m`` is not used.
    """
    slope = layer.negative_slope
    formulation = gp.math.leaky_relu_with_binary_var
    return partial(formulation, negative_slope=slope)


def convert_pool2d(m: gp.Container, layer: torch.nn.MaxPool2d | torch.nn.AvgPool2d):
    """
    Convert a PyTorch 2-D pooling layer into the matching GAMSPy formulation.

    The pooling kind is chosen from the layer's class name. ``MaxPool2d``
    becomes ``gp.formulations.MaxPool2d``; any other class becomes
    ``gp.formulations.AvgPool2d``. ``kernel_size``, ``stride`` and
    ``padding`` are copied from the layer.

    Parameters
    ----------
    m : Container
        Container the formulation is created in.
    layer : torch.nn.MaxPool2d | torch.nn.AvgPool2d
        Layer to convert.

    Raises
    ------
    ValidationError
        If the layer uses an option the GAMSPy formulation does not receive:
        a dilation other than 1 or ``return_indices`` (max pooling), a
        ``divisor_override`` (average pooling), or ``ceil_mode`` (both).
    """
    is_max_pool = layer.__class__.__name__ == "MaxPool2d"

    if is_max_pool:
        dilation = layer.dilation
        if isinstance(dilation, int):
            dilation = (dilation,)

        if any(d != 1 for d in dilation):
            raise ValidationError("Pool2d is not supported when dilation is not 1")

        if layer.return_indices:
            raise ValidationError("Pool2d is not supported when return_indices is True")
    elif getattr(layer, "divisor_override", None) is not None:
        # The GAMSPy AvgPool2d call below has no divisor argument. A custom
        # divisor would be silently dropped and the model would be wrong.
        raise ValidationError("Pool2d is not supported when divisor_override is set")

    if layer.ceil_mode:
        raise ValidationError("Pool2d is not supported when ceil_mode is True")

    pool_cls = gp.formulations.MaxPool2d if is_max_pool else gp.formulations.AvgPool2d
    return pool_cls(
        m,
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
    )


# Default lookup table from a PyTorch layer's class name to the function
# that converts it. TorchSequential copies this table and lets callers
# override or extend it. Both pooling layers share one converter, which
# branches on the class name itself.
_DEFAULT_CONVERTERS = {
    "Linear": convert_linear,
    "Conv1d": convert_conv1d,
    "Conv2d": convert_conv2d,
    "ReLU": convert_relu,
    "LeakyReLU": convert_leaky_relu,
    **dict.fromkeys(("MaxPool2d", "AvgPool2d"), convert_pool2d),
}


class TorchSequential:
    """
    Formulation generator for Sequential Layer from PyTorch.
    This is a convenience formulation that builds upon other formulations.

    Parameters
    ----------
    container : Container
        Container that will contain the new variable and equations.
    network : torch.nn.Sequential
        Sequential network that will be translated to GAMSPy
    layer_converters : dict | None
        You can change default layer converters or add support for not
        implemented layers through this dictionary. Key is the class name
        as string, and value expects a function that returns GAMSPy
        formulation given container and the PyTorch layer.

    Examples
    --------
    >>> import gamspy as gp
    >>> from gamspy.math import dim
    >>> def embed():
    ...     try:
    ...         import torch
    ...     except ModuleNotFoundError as e:
    ...         print("[10, 4, 30, 30]")
    ...         return
    ...     m = gp.Container()
    ...     model = torch.nn.Sequential(
    ...         torch.nn.Conv2d(3, 4, 3, bias=True),
    ...         torch.nn.ReLU(),
    ...         torch.nn.Conv2d(4, 4, 3, bias=False, padding=1),
    ...     )
    ...     x = gp.Variable(m, domain=dim([10, 3, 32, 32]))
    ...     seq_formulation = gp.formulations.TorchSequential(m, model)
    ...     y, eqs = seq_formulation(x)
    ...     print([len(d) for d in y.domain])
    >>> embed()
    [10, 4, 30, 30]

    """

    def __init__(
        self,
        container: gp.Container,
        network: torch.nn.Sequential,
        layer_converters: dict | None = None,
    ):
        # PyTorch is optional for gamspy, so it is imported lazily here.
        try:
            import torch
        except ModuleNotFoundError as e:
            raise ValidationError(
                "You must first install PyTorch to use this functionality."
            ) from e

        # Copy the defaults so user overrides never mutate the shared table.
        self._layer_converters = _DEFAULT_CONVERTERS.copy()
        if layer_converters is not None:
            self._layer_converters.update(layer_converters)

        # Run the converters with gradient tracking disabled.
        with torch.no_grad():
            self.layers = [self._convert_layer(container, layer) for layer in network]

    def _convert_layer(self, container: gp.Container, layer):
        """Convert a single PyTorch layer using the converter registered for its class name."""
        clz = layer.__class__.__name__
        if clz not in self._layer_converters:
            raise ValidationError(f"Formulation for {clz} not implemented!")

        converted = self._layer_converters[clz](container, layer)
        return converted

    def _update_dict_layered(
        self, large: dict[str, typing.Any], small: dict[str, typing.Any], layer_num: int
    ):
        """Copy every entry of ``small`` into ``large`` under the key ``"<layer_num>.<key>"``."""
        for key in small:
            large[f"{layer_num}.{key}"] = small[key]

    def __call__(self, input: gp.Variable) -> gp.formulations.FormulationResult:
        """
        This method returns a **`FormulationResult`** object, which includes
        symbols and outputs created by its underlying layers. The way to
        access these underlying symbols depends on what the sub-layer returns:

        1. **If a Sub-Layer Returns a `FormulationResult`**

           All symbols created by that sub-layer can be accessed within the
           main `FormulationResult`. Each symbol's name is **prefixed** with
           its layer number, followed by a dot (`.`).

           * **Access Format:** `<layer_num>.<symbol_name>`
           * **Example:** If the first layer creates a parameter named
             `bias`, it is accessed as `0.bias` in `parameters_created`.

           The sub-layer's `matches` are merged without a prefix.

        2. **If a Sub-Layer Returns the "Old Style" Output (Output Variable
           and List of Equations)**

           For backward compatibility, if a sub-layer returns an output
           variable and a list of equations instead of a
           `FormulationResult`, they are accessed as follows:

           * **Output Variable:** The main output variable is named:

             * **Access Format:** `<layer_num>.output`

           * **Equations:** Each returned equation is sequentially named:

             * **Access Format:** `<layer_num>.eq_<eq_number>` (where
               `eq_number` starts at 0, 1, 2...)
             * **Example:** The first equation from the third layer is
               accessed as `2.eq_0` in `equations_created`.

        Parameters
        ----------
        input : Variable
            Input fed to the first layer. Each layer's output becomes the
            next layer's input.

        Returns
        -------
        FormulationResult
            `result` holds the last layer's output. If any matches were
            collected, they are also stored as `extra_return`.
        """
        result = gp.formulations.FormulationResult()
        out: gp.Variable | gp.Parameter | None = input
        for layer_num, layer in enumerate(self.layers):
            output = layer(out)
            if isinstance(output, gp.formulations.FormulationResult):
                self._update_dict_layered(
                    result.equations_created, output.equations_created, layer_num
                )
                self._update_dict_layered(
                    result.variables_created, output.variables_created, layer_num
                )
                self._update_dict_layered(
                    result.sets_created, output.sets_created, layer_num
                )
                self._update_dict_layered(
                    result.parameters_created, output.parameters_created, layer_num
                )
                self._update_dict_layered(result.other, output.other, layer_num)
                # matches are not named
                result.matches.update(output.matches)
                out = output.result
            else:
                # old interface only provides output var and list of equations
                out, layer_eqs = output
                result.variables_created[f"{layer_num}.output"] = out
                for eq_num, eq in enumerate(layer_eqs):
                    result.equations_created[f"{layer_num}.eq_{eq_num}"] = eq

        result.result = out
        # to prevent forgetting matches when unpacking
        if result.matches:
            result.extra_return = result.matches

        return result