https://lafor.ge/feed.xml

Refinement Type

2022-11-29

Bonjour à toutes et à tous 😀

Une fois n'est pas coutume, un article simple ^^

J'ai découvert ça cette nuit, du coup je vous le partage.

Il existe des cas où les types de base d'un langage ne sont pas suffisants, mais où l'utilisation d'une structure serait de trop.

C'est le contexte parfait pour dégainer les Refinement Types.

Solution Naïve

Un exemple pour comprendre.

Vous avez à mesurer des surfaces rectangulaires.

Pour une raison que je tairai ici, les capteurs sont un peu fatigués et renvoient parfois des valeurs négatives.

On veut qu'à ce moment-là, on ne prenne pas en compte la mesure et que l'on n'effectue pas non plus le calcul d'aire correspondant, sinon nous allons nous retrouver avec des aires négatives. Ce qui n'est pas top...

La solution naïve est de réaliser ceci :

fn calc_area(w: i8, h:i8) -> Option<i8> {
    if w < 0 || h < 0 {
        None
    } else {
        Some(w * h)
    }
}


fn main() {

    let witdths = [5, -22, -15, 3];
    let heights = [5, 2, -12, -7];


   for i in 0..4 {
       let area = calc_aire(witdths[i], heights[i]);
       println!("{:?}", area);
   }
}

On obtient bien notre résultat voulu, si l'un des capteurs est dans les choux, on invalide le résultat.

Some(25)
None
None
None

Encapsulation

Maintenant essayons autre chose.

Pourquoi ne pas enfermer dans une boîte notre donnée provenant du capteur.

Cette boîte appelons la WorkingSensor.

Elle ne contient qu'un champ Option<i8> qui symbolise si la valeur du capteur a été prise en compte ou non.

struct WorkingSensor {
    inner: Option<i8>
}

On lui rajoute un constructeur.

impl WorkingSensor {
    fn new(value: i8) -> Self {
        if value < 0 {
            Self {inner : None }
        } else {
            Self {inner: Some(value)}
        }
    }
}

Puis, l'on modifie notre fonction calc_area.

On y fait plusieurs choses.

D'abord, on change les paramètres d'entrées i8 en des références de WorkingSensor puis on utilise le champ inner de la structure pour venir vérifier que les données du couple de mesures sont corrects.

Finalement, on unwrap et on réalise la multiplication.

fn calc_area(w: &WorkingSensor, h: &WorkingSensor) -> Option<i8> {
    if w.inner.is_some() && h.inner.is_some() {
        Some(w.inner.unwrap() * h.inner.unwrap())
    } else {
        None
    }
}

On peut alors modifier le main.

On crée des Vec<WorkingSensor> et on boucle dessus.

fn main() {

    let witdths = [5, -22, -15, 3].into_iter()
        .map(WorkingSensor::new).collect::<Vec<WorkingSensor>>();
    let heights = [5, 2, -12, -7].into_iter()
        .map(WorkingSensor::new).collect::<Vec<WorkingSensor>>();


   for i in 0..4 {
       let area = calc_area(&witdths[i], &heights[i]);
       println!("{:?}", area);
   }
}

Optimisons tout ça.

Premièrement, on va se débarrasser des .inner qui polluent la lisibilité.

Pour ça on implémente le trait Deref.

use std::ops::Deref;

impl Deref for WorkingSensor {
    type Target = Option<i8>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

Ce qui permet d'écrire :

fn calc_area(w: &WorkingSensor, h: &WorkingSensor) -> Option<i8> {
    if w.is_some() && h.is_some() {
        Some(w.unwrap() * h.unwrap())
    } else {
        None
    }
}

C'est mieux, mais le unwrap me dérange également.

Est-ce que l'on peut encapsuler la complexité ?

Oui! Biensûr ! 😁

impl Mul for &WorkingSensor {
    type Output = Option<i8>;

    fn mul(self, rhs: Self) -> Self::Output {
        match (self.inner, rhs.inner) {
            (Some(self_value), Some(rhs_value)) => Some(self_value * rhs_value),
            _ => None,
        }
    }
}

Ce qui simplifie grandement notre méthode calc_area.

fn calc_area(w: &WorkingSensor, h: &WorkingSensor) -> Option<i8> {
    w * h
}

Opérations conformes

Je ne sais pas vous, mais moi, j'aime bien lorsque les opérations mathématiques ne renvoient pas un type de donnée différent du type des entrées.

Ici, on multiplie des WorkingSensor par des WorkingSensor et cela nous donne un Option<i8>.

C'est un peu étrange. Mais ça serait tout aussi étrange de se retrouver avec un WorkingSensor comme aire possible.

Il faut que l'on généralise un peu, et le point commun entre un WorkingSensor et une aire, c'est que tous deux sont positifs.

Nous allons alors renommer notre structure en PositiveNumber. On lui définit également le trait Debug.

#[derive(Debug)]
struct PositiveNumber {
    inner: Option<i8>,
}

On peut ainsi implémenter notre multiplication comme on le souhaite.

impl Mul for &PositiveNumber {
    type Output = PositiveNumber;

    fn mul(self, rhs: Self) -> Self::Output {
        match (self.inner, rhs.inner) {
            (Some(self_value), Some(rhs_value)) => {
                PositiveNumber::new((self_value * rhs_value) as i8)
            }
            _ => PositiveNumber { inner: None },
        }
    }
}

Ce qui permet de réécrire notre signature de calc_area qui renvoie désormais un PositiveNumber.

fn calc_area(w: &PositiveNumber, h: &PositiveNumber) -> PositiveNumber {
    w * h
}

On active tout ça via un main.

Et ça nous donne :

fn main() {
    let widths = [5, -22, -15, 3]
        .into_iter()
        .map(PositiveNumber::new)
        .collect::<Vec<PositiveNumber>>();
    let heights = [5, 2, -12, -7]
        .into_iter()
        .map(PositiveNumber::new)
        .collect::<Vec<PositiveNumber>>();

    for i in 0..4 {
        let area = calc_area(&widths[i], &heights[i]);
        println!("{:?}", area);
    }
}

Avec comme affichage :

PositiveNumber { inner: Some(50) }
PositiveNumber { inner: None }     
PositiveNumber { inner: None }     
PositiveNumber { inner: None }  

Comme maintenant, nous avons une multiplication conforme, nous pouvons définir une opération de mise à l'échelle qui vient réaliser une multiplication scalaire par 2, par exemple de notre aire.

fn calc_area(w: &PositiveNumber, h: &PositiveNumber) -> PositiveNumber {
    w * h * 2
}

impl Mul<i8> for PositiveNumber {
    type Output = PositiveNumber;

    fn mul(self, rhs: i8) -> Self::Output {
        match self.inner {
            Some(self_value) => PositiveNumber::new(self_value * rhs),
            _ => PositiveNumber { inner: None },
        }
    }
}

Ce qui donne comme résultat :

PositiveNumber { inner: Some(200) }
PositiveNumber { inner: None }     
PositiveNumber { inner: None }     
PositiveNumber { inner: None }  

Refinement Type

C'est cool, mais je n'ai pas forcément envie de gérer tout le temps des entiers i8.

Nous allons généraliser !

D'abord, on change encore de nom.

Notre structure devient alors Refinement.

#[derive(Debug)]
struct Refinement<T> {
    inner: Option<T>,
}

On peut ainsi utiliser les génériques pour gérer n'importe quel type.

On généralise les implémentations.

impl<T> Deref for Refinement<T> {
    type Target = Option<T>;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

// On doit spécifier le type T pour qu'il soit multipliable

impl<T> Mul for &Refinement<T>
where
    T: Clone + Copy + Mul<Output = T>,
{
    type Output = Refinement<T>;

    fn mul(self, rhs: Self) -> Self::Output {
        match (self.inner, rhs.inner) {
            (Some(self_value), Some(rhs_value)) => Refinement::new(self_value * rhs_value),
            _ => Refinement { inner: None },
        }
    }
}

// On doit spécifier le type T pour qu'il soit multipliable

impl<T> Mul<T> for Refinement<T>
where
    T: Clone + Copy + Mul<Output = T>,
{
    type Output = Refinement<T>;

    fn mul(self, rhs: T) -> Self::Output {
        match self.inner {
            Some(self_value) => Refinement::new(self_value * rhs),
            _ => Refinement { inner: None },
        }
    }
}

Notre méthode calc_area devient

fn calc_area<T: Clone + Copy + Mul<Output = T>>(
    w: &Refinement<T>,
    h: &Refinement<T>,
) -> Refinement<T> {
    w * h
}

Et on peut aussi réaliser la mise à l'échelle également.

fn calc_area<T: Clone + Copy + Mul<Output = T>>(
    w: &Refinement<T>,
    h: &Refinement<T>,
    scale: T
) -> Refinement<T> {
    w * h * scale
}

Par contre, là, le constructeur pose problème... 🙄

Lorsque l'on connaissait le type de la valeur d'entrée, on pouvait créer un check statique.

fn new(value: i16) -> Self {
    if value < 0 {
        Self { inner: None }
    } else {
        Self {
            inner: Some(value as i8),
        }
    }
}

Sauf que si value est de type T. T pouvant être tout et n'importe quoi, on ne peut plus définir de if qui pourrait correspondre à ce type T.

impl<T> Refinement<T> {
    fn new(value: T) -> Self {
        // ????
        Refinement { inner: None }
    }
}

On doit alors rajouter un paramètre qui nous sert de vérification.

En Rust, il est possible de définir un type générique pour une closure.

impl<T> Refinement<T> {
    fn new<F>(value: T, predicate: F) -> Self
    where
        F: Fn(&T) -> bool,
        T: Clone,
    {
        if predicate(value) {
            Refinement {
                inner: Some(value.clone()),
            }
        } else {
            Refinement { inner: None }
        }
    }
}

Mais ça pose des problèmes dans beaucoup d'autres parties du code. À commencer par le main.

Qui doit prendre le prédicat pour chaque new.

Ce n'est vraiment pas l'idéal.

fn main() {
    let widths = [10, -22, -15, 3]
        .into_iter()
        .map(|x| Refinement::new(x, ???))
        .collect::<Vec<Refinement<i8>>>();
    let heights = [10, 2, -12, -7]
        .into_iter()
        .map(|x| Refinement::new(x, ???))
        .collect::<Vec<Refinement<i8>>>();

    for i in 0..4 {
        let area = calc_area(&widths[i], &heights[i]);
        println!("{:?}", area);
    }
}

Et dans le Mul:

impl<T> Mul for &Refinement<T>
where
    T: Clone + Copy + Mul<Output = T>,
{
    type Output = Refinement<T>;

    fn mul(self, rhs: Self) -> Self::Output {
        match (self.inner, rhs.inner) {
            (Some(self_value), Some(rhs_value)) => Refinement::new(self_value * rhs_value, ???),
            _ => Refinement { inner: None },
        }
    }
}

Que mettre à la place de ??? et comment le faire de manière élégante.

On va essayer de trouver un moyen de résoudre ce problème.

Predicate

On crée un trait Predicate auquel l'on va définir une méthode check qui renvoie un bool.

Elle prend une référence &T pour être compatible avec toute entrée.

trait Predicate<T> {
    fn check(value: &T) -> bool;
}

On définit une structure vide

#[derive(Debug)]
struct PositiveNumber;

Que l'on implémente pour notre trait Predicate.

impl Predicate<i8> for PositiveNumber {
    fn check(value: &i8) -> bool {
        *value > 0
    }
}

On vient stocker le prédicat directement dans la structure Refinement.

Pour cela, on utilise le marker PhantomData<P>.

PhantomData est une syntaxe du langage qui permet de définir un type générique dans une structure, même si l'on n'a pas de champ pour le faire.

Sans cela, le Refinement<T, P> ne compilerait pas, car on n'aurait pas de champ de type P.

#[derive(Debug)]
struct Refinement<T, P> {
    inner: Option<T>,
    predicate: PhantomData<P>,
}

On modifie alors le constructeur.

impl<T, P> Refinement<T, P>
where
    P: Predicate<T>,
{
    fn new(value: T) -> Self
    where
        T: Clone,
    {
        if P::check(&value) {
            Refinement {
                inner: Some(value.clone()),
                predicate: PhantomData,
            }
        } else {
            Refinement {
                inner: None,
                predicate: PhantomData,
            }
        }
    }
}

Toute la mécanique est réalisée par le P::check. En effet P étant un Predicate<T>, nous avons l'assurance qu'il existe une méthode statique P::check qui prend une référence de T et par conséquent même en ne connaissant pas la nature du prédicat, nous pouvons tout de même l'appeler sans crainte. 😀

Et de fait, on peut alors définir la nature de ce P par l'inférence de type offerte par Rust, en définissant ce que l'on désire collect.

Ici un i8 vérifié par le prédicat PositiveNumber :

fn main() {
    let widths = [10, -22, -15, 3]
        .into_iter()
        .map(Refinement::new)
        .collect::<Vec<Refinement<i8, PositiveNumber>>>();
    let heights = [10, 2, -12, -7]
        .into_iter()
        .map(Refinement::new)
        .collect::<Vec<Refinement<i8, PositiveNumber>>>();

    for i in 0..4 {
        let area = calc_area(&widths[i], &heights[i]);
        println!("{:?}", area);
    }
}

Et si on lance !

Refinement { inner: None, predicate: PhantomData<refinement::PositiveNumber> }
Refinement { inner: None, predicate: PhantomData<refinement::PositiveNumber> }
thread 'main' panicked at 'attempt to multiply with overflow', /rustc/a6b7274a462829f8ef08a1ddcdcec7ac80dbf3e1\library\core\src\ops\arith.rs:349:1

Panic !!!

Ah oui! Overflow de multiplication.

Mais grâce à notre système, on peut modifier le type facilement.

On se crée un nouveau prédicat PositiveBigNumber qui va être en mesure de prendre un i64.

#[derive(Debug)]
struct PositiveBigNumber;

impl Predicate<i64> for PositiveBigNumber {
    fn check(value: &i64) -> bool {
        *value > 0
    }
}

Et un main tout propre. On règle l'overflow via un i64 au lieu d'un i8.

fn main() {
    let widths = [10, -22, -15, 3]
        .into_iter()
        .map(Refinement::new)
        .collect::<Vec<Refinement<i64, PositiveBigNumber>>>();
    let heights = [10, 2, -12, -7]
        .into_iter()
        .map(Refinement::new)
        .collect::<Vec<Refinement<i64, PositiveBigNumber>>>();

    for i in 0..4 {
        let area = calc_area(&widths[i], &heights[i]);
        println!("{:?}", area);
    }
}

On peut alors nettoyer notre code.

En rajoutant tout d'abord un type custom.

type BigPositiveNumber = Refinement<i64, PositiveBigNumber>;

Puis en utilisant celui-ci dans notre méthode calc_area.

fn calc_area(w: &BigPositiveNumber, h: &BigPositiveNumber) -> BigPositiveNumber {
    w * h
}

Et en modifiant le main en conséquences.

fn main() {
    let widths = [10, -22, -15, 3]
        .into_iter()
        .map(Refinement::new)
        .collect::<Vec<BigPositiveNumber>>();
    let heights = [10, 2, -12, -7]
        .into_iter()
        .map(Refinement::new)
        .collect::<Vec<BigPositiveNumber>>();

    for i in 0..4 {
        let area = calc_area(&widths[i], &heights[i]);
        println!("{:?}", area);
    }
}

Ce qui donne :

Refinement { inner: Some(100), predicate: PhantomData<refinement::PositiveBigNumber> }
Refinement { inner: None, predicate: PhantomData<refinement::PositiveBigNumber> }     
Refinement { inner: None, predicate: PhantomData<refinement::PositiveBigNumber> }     
Refinement { inner: None, predicate: PhantomData<refinement::PositiveBigNumber> }     

Cosmétique

On va rajouter une méthode de display.

On ajoute au trait Predicate une méthode error :

trait Predicate<T> {
    fn check(value: &T) -> bool;
    fn error() -> String;
}

Que l'on implémente sommairement.

#[derive(Debug)]
struct PositiveNumber;

impl Predicate<i8> for PositiveNumber {
    fn check(value: &i8) -> bool {
        *value > 0
    }

    fn error() -> String {
        "Must be a positive value".to_string()
    }
}

#[derive(Debug)]
struct PositiveBigNumber;

impl Predicate<i64> for PositiveBigNumber {
    fn check(value: &i64) -> bool {
        *value > 0
    }

    fn error() -> String {
        "Must be a positive value".to_string()
    }
}

On peut ensuite définir le trait Display pour notre structure Refinement.

impl<T, P> Display for Refinement<T, P>
where
    P: Predicate<T>,
    T: Display,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match &self.inner {
            Some(x) => f.write_fmt(format_args!("{}", x)),
            None => write!(f, "{}", P::error()),
        }
    }
}

Et enfin modifier dans le main, la boucle de prints pour utiliser le Display.

for i in 0..4 {
    let area = calc_area(&widths[i], &heights[i]);
    println!("{}", area);
}

Ce qui nous affiche alors :

100
Must be a positive value
Must be a positive value
Must be a positive value
code complet
use std::fmt::{Display, Formatter};
use std::marker::PhantomData;
use std::ops::Mul;

fn calc_area(w: &BigPositiveNumber, h: &BigPositiveNumber) -> BigPositiveNumber {
    w * h
}

trait Predicate<T> {
    fn check(value: &T) -> bool;
    fn error() -> String;
}

#[derive(Debug)]
struct PositiveNumber;

impl Predicate<i8> for PositiveNumber {
    fn check(value: &i8) -> bool {
        *value > 0
    }

    fn error() -> String {
        "Must be a positive value".to_string()
    }
}

#[derive(Debug)]
struct PositiveBigNumber;

impl Predicate<i64> for PositiveBigNumber {
    fn check(value: &i64) -> bool {
        *value > 0
    }

    fn error() -> String {
        "Must be a positive value".to_string()
    }
}

#[derive(Debug)]
struct Refinement<T, P> {
    inner: Option<T>,
    predicate: PhantomData<P>,
}

impl<T, P> Display for Refinement<T, P>
where
    P: Predicate<T>,
    T: Display,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match &self.inner {
            Some(x) => f.write_fmt(format_args!("{}", x)),
            None => write!(f, "{}", P::error()),
        }
    }
}

impl<T, P> Mul for &Refinement<T, P>
where
    T: Clone + Copy + Mul<Output = T>,
    P: Predicate<T>,
{
    type Output = Refinement<T, P>;

    fn mul(self, rhs: Self) -> Self::Output {
        match (self.inner, rhs.inner) {
            (Some(self_value), Some(rhs_value)) => Refinement::new(self_value * rhs_value),
            _ => Refinement {
                inner: None,
                predicate: PhantomData,
            },
        }
    }
}

impl<T, P> Mul<T> for Refinement<T, P>
where
    T: Clone + Copy + Mul<Output = T>,
    P: Predicate<T>,
{
    type Output = Refinement<T, P>;

    fn mul(self, rhs: T) -> Self::Output {
        match self.inner {
            Some(self_value) => Refinement::new(self_value * rhs),
            _ => Refinement {
                inner: None,
                predicate: PhantomData,
            },
        }
    }
}

impl<T, P> Refinement<T, P>
where
    P: Predicate<T>,
{
    fn new(value: T) -> Self
    where
        T: Clone,
    {
        if P::check(&value) {
            Refinement {
                inner: Some(value.clone()),
                predicate: PhantomData,
            }
        } else {
            Refinement {
                inner: None,
                predicate: PhantomData,
            }
        }
    }
}

type BigPositiveNumber = Refinement<i64, PositiveBigNumber>;

fn main() {
    let widths = [10, -22, -15, 3]
        .into_iter()
        .map(Refinement::new)
        .collect::<Vec<BigPositiveNumber>>();
    let heights = [10, 2, -12, -7]
        .into_iter()
        .map(Refinement::new)
        .collect::<Vec<BigPositiveNumber>>();

    for i in 0..4 {
        let area = calc_area(&widths[i], &heights[i]);
        println!("{}", area);
    }
}

C'est quand même pas mal, non ? 😀

Conclusion

Les Refinement Types sont des objets qui permettent de s'assurer de la cohérence des données en mathématique on appellerait ceci un sous-ensemble.

Dans l'article notre prédicat était très simple. Mais l'on peut imaginer des prédicats très complexes qui permettent de valider des mots de passes par exemple.

On pourrait créer ce type par exemple :

type PasswordValid = Refinement<String, IsValidPassword>;

Je vous laisse imaginer les usages que vous pourriez en avoir.

Je vous remercie pour votre lecture et vous dis à la prochaine ❤️

avatar

Auteur: Akanoa

Je découvre, j'apprends, je comprends et j'explique ce que j'ai compris dans ce blog.

Ce travail est sous licence CC BY-NC-SA 4.0.