From 23e42bdb216f8f33b82148967c9781cbc992cd17 Mon Sep 17 00:00:00 2001
From: m-guberina <gubi.guberina@gmail.com>
Date: Thu, 20 Feb 2025 01:57:49 +0100
Subject: [PATCH] disjoint control works. has to be tuned and optimized
 (bottlenecks are known), but we're going with this

---
 .../cart_pulling/disjoint_control/demo.py     | 101 ++++++++++++++
 .../mpc_base_clik_arm_control_loop.py         | 132 ++++++++++++++++++
 examples/navigation/mobile_base_navigation.py |  67 +++++++++
 .../path_following_template.py                |   4 +-
 .../optimal_control/abstract_croco_ocp.py     |  14 +-
 .../croco_mpc_path_following.py               | 106 +++++++++++---
 .../path_following_croco_ocp.py               |  49 +++++++
 python/smc/robots/implementations/heron.py    |  74 +---------
 .../interfaces/mobile_base_interface.py       |  90 ++++++++++++
 9 files changed, 545 insertions(+), 92 deletions(-)
 create mode 100644 examples/cart_pulling/disjoint_control/demo.py
 create mode 100644 examples/cart_pulling/disjoint_control/mpc_base_clik_arm_control_loop.py

diff --git a/examples/cart_pulling/disjoint_control/demo.py b/examples/cart_pulling/disjoint_control/demo.py
new file mode 100644
index 0000000..29ed9f6
--- /dev/null
+++ b/examples/cart_pulling/disjoint_control/demo.py
@@ -0,0 +1,101 @@
+from smc import getRobotFromArgs
+from smc import getMinimalArgParser
+from smc.path_generation.maps.premade_maps import createSampleStaticMap
+from smc.path_generation.path_math.path2d_to_6d import path2D_to_SE3_fixed
+from smc.control.optimal_control.util import get_OCP_args
+from smc.control.cartesian_space import getClikArgs
+from smc.path_generation.planner import starPlanner, getPlanningArgs
+from smc.control.optimal_control.croco_mpc_point_to_point import EEAndBaseP2PMPC
+from smc.multiprocessing import ProcessManager
+from mpc_base_clik_arm_control_loop import BaseMPCANDEECLIKCartPulling
+
+import time
+import numpy as np
+from functools import partial
+import pinocchio as pin
+
+
+def get_args():
+    parser = getMinimalArgParser()
+    parser = get_OCP_args(parser)
+    parser = getClikArgs(parser)  # literally just for goal error
+    parser = getPlanningArgs(parser)
+    parser.add_argument(
+        "--handlebar-height",
+        type=float,
+        default=0.5,
+        help="heigh of handlebar of the cart to be pulled",
+    )
+    parser.add_argument(
+        "--base-to-handlebar-preferred-distance",
+        type=float,
+        default=0.5,
+        help="prefered path arclength from mobile base position to handlebar",
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = get_args()
+    robot = getRobotFromArgs(args)
+    # TODO: HOW IS IT POSSIBLE THAT T_W_E IS WRONG WITHOUT STEP CALLED HERE?????????????????
+    robot._step()
+    T_w_e = robot.T_w_e
+    robot._q[0] = 9.0
+    robot._q[1] = 4.0
+    robot._step()
+    x0 = np.concatenate([robot.q, robot.v])
+    goal = np.array([0.5, 5.5])
+
+    planning_function = partial(starPlanner, goal)
+    # here we're following T_w_e reference so that's what we send
+    path_planner = ProcessManager(
+        args, planning_function, T_w_e.translation[:2], 3, None
+    )
+    _, map_as_list = createSampleStaticMap()
+    if args.visualizer:
+        robot.sendRectangular2DMapToVisualizer(map_as_list)
+        # time.sleep(5)
+
+    T_w_e = robot.T_w_e
+    data = None
+    # get first path
+    ##########################################3
+    #                initialize
+    ###########################################
+    while data is None:
+        path_planner.sendCommand(T_w_e.translation[:2])
+        data = path_planner.getData()
+        # time.sleep(1 / args.ctrl_freq)
+        time.sleep(1)
+
+    _, path2D = data
+    path2D = np.array(path2D).reshape((-1, 2))
+    pathSE3 = path2D_to_SE3_fixed(path2D, args.handlebar_height)
+    if args.visualizer:
+        # TODO: document this somewhere
+        robot.visualizer_manager.sendCommand({"Mgoal": pathSE3[0]})
+    if np.linalg.norm(pin.log6(T_w_e.actInv(pathSE3[0]))) > 1e-2:
+        print("going to initial path position")
+        p_base = pathSE3[0].translation.copy()
+        p_base[0] -= args.base_to_handlebar_preferred_distance
+        p_base[2] = 0.0
+        print(pathSE3[0].translation)
+        print(p_base)
+        # TODO: UNCOMMENT
+        EEAndBaseP2PMPC(args, robot, pathSE3[0], p_base)
+    print("initialized!")
+    BaseMPCANDEECLIKCartPulling(args, robot, path_planner)
+
+    print("final position:", robot.T_w_e)
+
+    if args.real:
+        robot.stopRobot()
+
+    if args.save_log:
+        robot._log_manager.saveLog()
+        robot._log_manager.plotAllControlLoops()
+
+    if args.visualizer:
+        robot.killManipulatorVisualizer()
diff --git a/examples/cart_pulling/disjoint_control/mpc_base_clik_arm_control_loop.py b/examples/cart_pulling/disjoint_control/mpc_base_clik_arm_control_loop.py
new file mode 100644
index 0000000..7ac1045
--- /dev/null
+++ b/examples/cart_pulling/disjoint_control/mpc_base_clik_arm_control_loop.py
@@ -0,0 +1,132 @@
+from smc.robots.interfaces.whole_body_interface import SingleArmWholeBodyInterface
+from smc.control.control_loop_manager import ControlLoopManager
+from smc.multiprocessing.process_manager import ProcessManager
+from smc.control.optimal_control.abstract_croco_ocp import CrocoOCP
+from smc.control.optimal_control.path_following_croco_ocp import (
+    BasePathFollowingOCP,
+)
+from smc.path_generation.path_math.path2d_to_6d import (
+    path2D_to_SE3_fixed,
+)
+from smc.path_generation.path_math.cart_pulling_path_math import (
+    construct_EE_path,
+)
+from smc.path_generation.path_math.path_to_trajectory import path2D_timed
+from smc.control.controller_templates.path_following_template import (
+    PathFollowingFromPlannerControlLoop,
+)
+from smc.control.cartesian_space.ik_solvers import dampedPseudoinverse
+from smc.control.optimal_control.croco_mpc_path_following import initializePastData
+
+import numpy as np
+from functools import partial
+import types
+from argparse import Namespace
+from pinocchio import SE3, log6
+from collections import deque
+
+
+def BaseMPCEECLIKPathFollowingFromPlannerMPCControlLoop(
+    ocp: CrocoOCP,
+    path2D_untimed_base: np.ndarray,
+    args: Namespace,
+    robot: SingleArmWholeBodyInterface,
+    t: int,
+    past_data: dict[str, deque[np.ndarray]],
+) -> tuple[np.ndarray, dict[str, np.ndarray]]:
+
+    robot._mode = SingleArmWholeBodyInterface.control_mode.whole_body
+    p = robot.T_w_b.translation[:2]
+    max_base_v = np.linalg.norm(robot._max_v[:2])
+    path_base = path2D_timed(args, path2D_untimed_base, max_base_v)
+    path_base = np.hstack((path_base, np.zeros((len(path_base), 1))))
+
+    x0 = np.concatenate([robot.q, robot.v])
+    ocp.warmstartAndReSolve(x0, data=(path_base))
+    xs = np.array(ocp.solver.xs)
+    v_cmd = xs[1, robot.model.nq :]
+
+    pathSE3_handlebar = construct_EE_path(args, p, past_data["path2D_untimed"])
+    robot._mode = SingleArmWholeBodyInterface.control_mode.arm_only
+
+    T_w_e = robot.T_w_e
+    # first check whether we're at the goal
+    SEerror = T_w_e.actInv(pathSE3_handlebar[0])
+    err_vector = log6(SEerror).vector
+    J = robot.getJacobian()
+    # compute the joint velocities based on controller you passed
+    # qd = ik_solver(J, err_vector, past_qd=past_data['dqs_cmd'][-1])
+    v_arm = dampedPseudoinverse(1e-2, J, err_vector)
+    robot._mode = SingleArmWholeBodyInterface.control_mode.whole_body
+
+    v_cmd[3:] = v_arm
+
+    if args.visualizer:
+        if t % int(np.ceil(args.ctrl_freq / 25)) == 0:
+            robot.visualizer_manager.sendCommand({"path": path_base})
+            robot.visualizer_manager.sendCommand({"frame_path": pathSE3_handlebar})
+
+    err_vector_ee = log6(robot.T_w_e.actInv(pathSE3_handlebar[0]))
+    err_vector_base = np.linalg.norm(p - path_base[0][:2])  # z axis is irrelevant
+    log_item = {}
+    log_item["err_vec_ee"] = err_vector_ee
+    log_item["err_norm_ee"] = np.linalg.norm(err_vector_ee).reshape((1,))
+    log_item["err_norm_base"] = np.linalg.norm(err_vector_base).reshape((1,))
+    return v_cmd, log_item
+
+
+def BaseMPCANDEECLIKCartPulling(
+    args: Namespace,
+    robot: SingleArmWholeBodyInterface,
+    path_planner: ProcessManager | types.FunctionType,
+) -> None:
+    """
+    BaseAndEEPathFollowingMPC
+    -----
+    run mpc for a point-to-point inverse kinematics.
+    note that the actual problem is solved on
+    a dynamics level, and velocities we command
+    are actually extracted from the state x(q,dq).
+    """
+
+    T_w_e = robot.T_w_e
+    x0 = np.concatenate([robot.q, robot.v])
+    ocp = BasePathFollowingOCP(args, robot, x0)
+    ocp.solveInitialOCP(x0)
+
+    max_base_v = np.linalg.norm(robot._max_v[:2])
+
+    path2D_handlebar = initializePastData(args, T_w_e, robot.q[:2], float(max_base_v))
+
+    if type(path_planner) == types.FunctionType:
+        raise NotImplementedError
+    else:
+        get_position = lambda robot: robot.q[:2]
+        controlLoop = partial(
+            PathFollowingFromPlannerControlLoop,
+            path_planner,
+            get_position,
+            ocp,
+            BaseMPCEECLIKPathFollowingFromPlannerMPCControlLoop,
+            args,
+            robot,
+        )
+    log_item = {
+        "qs": np.zeros(robot.model.nq),
+        "dqs": np.zeros(robot.model.nv),
+        "err_vec_ee": np.zeros((6,)),
+        "err_norm_ee": np.zeros((1,)),
+        "err_norm_base": np.zeros((1,)),
+    }
+    save_past_dict = {"path2D_untimed": T_w_e.translation[:2]}
+    loop_manager = ControlLoopManager(
+        robot, controlLoop, args, save_past_dict, log_item
+    )
+
+    # actually put past data into the past window
+    loop_manager.past_data["path2D_untimed"].clear()
+    loop_manager.past_data["path2D_untimed"].extend(
+        path2D_handlebar[i] for i in range(args.past_window_size)
+    )
+
+    loop_manager.run()
diff --git a/examples/navigation/mobile_base_navigation.py b/examples/navigation/mobile_base_navigation.py
index e69de29..3685bcb 100644
--- a/examples/navigation/mobile_base_navigation.py
+++ b/examples/navigation/mobile_base_navigation.py
@@ -0,0 +1,67 @@
+from smc import getRobotFromArgs
+from smc.robots.interfaces.whole_body_interface import SingleArmWholeBodyInterface
+from smc import getMinimalArgParser
+from smc.path_generation.maps.premade_maps import createSampleStaticMap
+from smc.path_generation.path_math.path2d_to_6d import path2D_to_SE3_fixed
+from smc.control.optimal_control.util import get_OCP_args
+from smc.control.cartesian_space import getClikArgs
+from smc.path_generation.planner import starPlanner, getPlanningArgs
+from smc.control.optimal_control.croco_mpc_path_following import (
+    CrocoBasePathFollowingMPC,
+)
+from smc.multiprocessing import ProcessManager
+
+import time
+import numpy as np
+from functools import partial
+import pinocchio as pin
+
+
+def get_args():
+    parser = getMinimalArgParser()
+    parser = get_OCP_args(parser)
+    parser = getClikArgs(parser)  # literally just for goal error
+    parser = getPlanningArgs(parser)
+    args = parser.parse_args()
+    return args
+
+
+# DA BI OVAJ MPC SEX FUNKCIONIRAO, MORAS MU DAT PIN.MODEL SAMO BAZE!!!
+# ili to il ce se vrtit GRO sporo bez razloga
+
+if __name__ == "__main__":
+    args = get_args()
+    robot = getRobotFromArgs(args)
+    # TODO: for ocp you want to pass only the mobile base model
+    # robot._mode = SingleArmWholeBodyInterface.control_mode.base_only
+    robot._step()
+    robot._q[0] = 9.0
+    robot._q[1] = 4.0
+    robot._step()
+    x0 = np.concatenate([robot.q, robot.v])
+    goal = np.array([0.5, 5.5])
+    T_w_b = robot.T_w_b
+
+    planning_function = partial(starPlanner, goal)
+    # here we're following T_w_e reference so that's what we send
+    path_planner = ProcessManager(
+        args, planning_function, T_w_b.translation[:2], 3, None
+    )
+    _, map_as_list = createSampleStaticMap()
+    if args.visualizer:
+        robot.sendRectangular2DMapToVisualizer(map_as_list)
+        # time.sleep(5)
+
+    CrocoBasePathFollowingMPC(args, robot, x0, path_planner)
+
+    print("final position:", robot.T_w_b.translation)
+
+    if args.real:
+        robot.stopRobot()
+
+    if args.save_log:
+        robot._log_manager.saveLog()
+        robot._log_manager.plotAllControlLoops()
+
+    if args.visualizer:
+        robot.killManipulatorVisualizer()
diff --git a/python/smc/control/controller_templates/path_following_template.py b/python/smc/control/controller_templates/path_following_template.py
index e7c815e..b854f53 100644
--- a/python/smc/control/controller_templates/path_following_template.py
+++ b/python/smc/control/controller_templates/path_following_template.py
@@ -66,6 +66,8 @@ def PathFollowingFromPlannerControlLoop(
 
     log_item["qs"] = robot.q
     log_item["dqs"] = robot.v
-    # NOTE: shouldn't be here
+    # NOTE: shouldn't be here, but this temporarily makes my life easier
+    # sorry future guy looking at this
+    # TODO: enable if you want to see base and ee path following ocp in action
     save_past_item["path2D_untimed"] = p
     return breakFlag, save_past_item, log_item
diff --git a/python/smc/control/optimal_control/abstract_croco_ocp.py b/python/smc/control/optimal_control/abstract_croco_ocp.py
index 5669969..2aa452f 100644
--- a/python/smc/control/optimal_control/abstract_croco_ocp.py
+++ b/python/smc/control/optimal_control/abstract_croco_ocp.py
@@ -265,11 +265,11 @@ class CrocoOCP(abc.ABC):
         xs = [x0] * (self.solver.problem.T + 1)
         us = self.solver.problem.quasiStatic([x0] * self.solver.problem.T)
 
-        #start = time.time()
-        #self.solver.solve(xs, us, 500, False, 1e-9)
+        # start = time.time()
+        # self.solver.solve(xs, us, 500, False, 1e-9)
         self.solver.solve(xs, us, self.args.max_solver_iter)
-        #end = time.time()
-        #print("solved in:", end - start, "seconds")
+        # end = time.time()
+        # print("solved in:", end - start, "seconds")
 
     def getSolvedReference(self) -> dict[str, Any]:
         # solver.solve()
@@ -284,9 +284,11 @@ class CrocoOCP(abc.ABC):
     # NOTE: this is ugly, but idk how to deal with the fact that i don't know
     # which kind of arguments this function needs
     def updateCosts(self, data):
-        raise NotImplementedError("if you want to warmstart and resolve, you need \
+        raise NotImplementedError(
+            "if you want to warmstart and resolve, you need \
             to specify how do you update the cost function (could be nothing) \
-            in between resolving")
+            in between resolving"
+        )
 
     def warmstartAndReSolve(self, x0: np.ndarray, data=None) -> None:
         self.solver.problem.x0 = x0
diff --git a/python/smc/control/optimal_control/croco_mpc_path_following.py b/python/smc/control/optimal_control/croco_mpc_path_following.py
index 134e942..d163fb6 100644
--- a/python/smc/control/optimal_control/croco_mpc_path_following.py
+++ b/python/smc/control/optimal_control/croco_mpc_path_following.py
@@ -1,8 +1,12 @@
+from smc.robots.interfaces.mobile_base_interface import MobileBaseInterface
 from smc.robots.interfaces.single_arm_interface import SingleArmInterface
+from smc.robots.interfaces.whole_body_interface import SingleArmWholeBodyInterface
+from smc.robots.interfaces.mobile_base_interface import MobileBaseInterface
 from smc.control.control_loop_manager import ControlLoopManager
 from smc.multiprocessing.process_manager import ProcessManager
 from smc.control.optimal_control.abstract_croco_ocp import CrocoOCP
 from smc.control.optimal_control.path_following_croco_ocp import (
+    BasePathFollowingOCP,
     CrocoEEPathFollowingOCP,
     BaseAndEEPathFollowingOCP,
     BaseAndDualArmEEPathFollowingOCP,
@@ -26,14 +30,82 @@ from pinocchio import SE3, log6
 from collections import deque
 
 
+def CrocoBasePathFollowingFromPlannerMPCControlLoop(
+    ocp: CrocoOCP,
+    path2D_untimed: np.ndarray,
+    args: Namespace,
+    robot: MobileBaseInterface,
+    t: int,
+    _: dict[str, deque[np.ndarray]],
+) -> tuple[np.ndarray, dict[str, np.ndarray]]:
+
+    p = robot.T_w_b.translation[:2]
+    max_base_v = np.linalg.norm(robot._max_v[:2])
+    path_base = path2D_timed(args, path2D_untimed, max_base_v)
+    path_base = np.hstack((path_base, np.zeros((len(path_base), 1))))
+
+    if args.visualizer:
+        if t % int(np.ceil(args.ctrl_freq / 25)) == 0:
+            robot.visualizer_manager.sendCommand({"path": path_base})
+
+    x0 = np.concatenate([robot.q, robot.v])
+    ocp.warmstartAndReSolve(x0, data=(path_base))
+    xs = np.array(ocp.solver.xs)
+    v_cmd = xs[1, robot.model.nq :]
+
+    err_vector_base = np.linalg.norm(p - path_base[0][:2])  # z axis is irrelevant
+    log_item = {}
+    log_item["err_norm_base"] = np.linalg.norm(err_vector_base).reshape((1,))
+    return v_cmd, log_item
+
+def CrocoBasePathFollowingMPC(
+    args: Namespace,
+    robot: MobileBaseInterface,
+    x0: np.ndarray,
+    path_planner: ProcessManager | types.FunctionType,
+) -> None:
+    """
+    CrocoBasePathFollowingMPC
+    -----
+    """
+
+    ocp = BasePathFollowingOCP(args, robot, x0)
+    x0 = np.concatenate([robot.q, robot.v])
+    ocp.solveInitialOCP(x0)
+
+    if type(path_planner) == types.FunctionType:
+        raise NotImplementedError
+    else:
+        get_position = lambda robot: robot.T_w_b.translation[:2]
+        controlLoop = partial(
+            PathFollowingFromPlannerControlLoop,
+            path_planner,
+            get_position,
+            ocp,
+            CrocoBasePathFollowingFromPlannerMPCControlLoop,
+            args,
+            robot,
+        )
+    log_item = {
+        "qs": np.zeros(robot.nq),
+        "dqs": np.zeros(robot.nv),
+        "err_norm_base" : np.zeros((1,))
+    }
+    save_past_item = {}
+    loop_manager = ControlLoopManager(
+        robot, controlLoop, args, save_past_item, log_item
+    )
+    loop_manager.run()
+
+
 def CrocoEEPathFollowingMPCControlLoop(
-    args,
+    args: Namespace,
     robot: SingleArmInterface,
     ocp: CrocoOCP,
     path_planner: types.FunctionType,
-    t,
-    _,
-):
+    t: int,
+    _: dict[str, deque[np.ndarray]],
+) -> tuple[np.ndarray, dict[str, np.ndarray]]:
     """
     CrocoPathFollowingMPCControlLoop
     -----------------------------
@@ -78,7 +150,7 @@ def CrocoEEPathFollowingFromPlannerMPCControlLoop(
     args: Namespace,
     robot: SingleArmInterface,
     t: int,
-    _: dict[str, np.ndarray],
+    _: dict[str, deque[np.ndarray]],
 ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
     """
     CrocoPathFollowingMPCControlLoop
@@ -118,7 +190,7 @@ def CrocoEEPathFollowingMPC(
     robot: SingleArmInterface,
     x0: np.ndarray,
     path_planner: ProcessManager | types.FunctionType,
-):
+) -> None:
     """
     CrocoEndEffectorPathFollowingMPC
     -----
@@ -161,13 +233,13 @@ def CrocoEEPathFollowingMPC(
 
 
 def BaseAndEEPathFollowingMPCControlLoop(
-    args,
+    args: Namespace,
     robot,
     ocp: CrocoOCP,
     path_planner: types.FunctionType,
-    t,
-    _,
-):
+    t: int,
+    _: dict[str, deque[np.ndarray]],
+) -> tuple[np.ndarray, dict[str, np.ndarray]]:
     """
     CrocoPathFollowingMPCControlLoop
     -----------------------------
@@ -211,7 +283,7 @@ def BaseAndEEPathFollowingFromPlannerMPCControlLoop(
     ocp: CrocoOCP,
     path2D_untimed_base: np.ndarray,
     args: Namespace,
-    robot: SingleArmInterface,
+    robot: SingleArmWholeBodyInterface,
     t: int,
     past_data: dict[str, deque[np.ndarray]],
 ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
@@ -224,10 +296,10 @@ def BaseAndEEPathFollowingFromPlannerMPCControlLoop(
 
     pathSE3_handlebar = construct_EE_path(args, p, past_data["path2D_untimed"])
 
-    #print("BASEcurrent_position", p)
-    #print("EEcurrent_position", robot.T_w_e.translation)
-    #print("=" * 5, "desired handlebar traj", "="  * 5)
-    #for pose in pathSE3_handlebar:
+    # print("BASEcurrent_position", p)
+    # print("EEcurrent_position", robot.T_w_e.translation)
+    # print("=" * 5, "desired handlebar traj", "="  * 5)
+    # for pose in pathSE3_handlebar:
     #    print(pose.translation)
 
     ###########################################
@@ -272,9 +344,9 @@ def initializePastData(
 
 def BaseAndEEPathFollowingMPC(
     args: Namespace,
-    robot: SingleArmInterface,
+    robot: SingleArmWholeBodyInterface,
     path_planner: ProcessManager | types.FunctionType,
-):
+) -> None:
     """
     BaseAndEEPathFollowingMPC
     -----
diff --git a/python/smc/control/optimal_control/path_following_croco_ocp.py b/python/smc/control/optimal_control/path_following_croco_ocp.py
index 6d75a49..092f783 100644
--- a/python/smc/control/optimal_control/path_following_croco_ocp.py
+++ b/python/smc/control/optimal_control/path_following_croco_ocp.py
@@ -1,5 +1,6 @@
 # TODO: make a bundle method which solves and immediately follows the traj.
 from smc.control.optimal_control.abstract_croco_ocp import CrocoOCP
+from smc.robots.interfaces.mobile_base_interface import MobileBaseInterface
 from smc.robots.robotmanager_abstract import AbstractRobotManager
 from smc.robots.interfaces.single_arm_interface import SingleArmInterface
 from smc.robots.interfaces.dual_arm_interface import DualArmInterface
@@ -9,6 +10,54 @@ import crocoddyl
 from argparse import Namespace
 
 
+class BasePathFollowingOCP(CrocoOCP):
+    """
+    createBaseAndEEPathFollowingOCP
+    -------------------------------
+    creates a path following problem.
+    it is instantiated to just to stay at the current position.
+    NOTE: the path MUST be time indexed with the SAME time used between the knots
+    """
+
+    def __init__(self, args, robot: MobileBaseInterface, x0):
+        goal = None
+        super().__init__(args, robot, x0, goal)
+
+    def constructTaskObjectiveFunction(self, goal) -> None:
+        path_base = [np.append(self.x0[:2], 0.0)] * self.args.n_knots
+
+        for i in range(self.args.n_knots):
+            baseTranslationResidual = crocoddyl.ResidualModelFrameTranslation(
+                self.state, self.robot.base_frame_id, path_base[i], self.state.nv
+            )
+            baseTrackingCost = crocoddyl.CostModelResidual(
+                self.state, baseTranslationResidual
+            )
+            self.runningCostModels[i].addCost(
+                "base_translation" + str(i),
+                baseTrackingCost,
+                self.args.base_translation_cost,
+            )
+
+        self.terminalCostModel.addCost(
+            "base_translation" + str(self.args.n_knots),
+            baseTrackingCost,
+            self.args.base_translation_cost,
+        )
+
+    def updateCosts(self, data):
+        path_base = data
+        for i, runningModel in enumerate(self.solver.problem.runningModels):
+            runningModel.differential.costs.costs[
+                "base_translation" + str(i)
+            ].cost.residual.reference = path_base[i]
+
+        # idk if that's necessary
+        self.solver.problem.terminalModel.differential.costs.costs[
+            "base_translation" + str(self.args.n_knots)
+        ].cost.residual.reference = path_base[-1]
+
+
 class CrocoEEPathFollowingOCP(CrocoOCP):
     """
     createCrocoEEPathFollowingOCP
diff --git a/python/smc/robots/implementations/heron.py b/python/smc/robots/implementations/heron.py
index 99cd46b..2f740df 100644
--- a/python/smc/robots/implementations/heron.py
+++ b/python/smc/robots/implementations/heron.py
@@ -2,7 +2,10 @@ from smc.robots.abstract_simulated_robot import AbstractSimulatedRobotManager
 from smc.robots.interfaces.force_torque_sensor_interface import (
     ForceTorqueOnSingleArmWrist,
 )
-from smc.robots.interfaces.mobile_base_interface import MobileBaseInterface
+from smc.robots.interfaces.mobile_base_interface import (
+    MobileBaseInterface,
+    get_mobile_base_model,
+)
 from smc.robots.interfaces.whole_body_interface import SingleArmWholeBodyInterface
 from smc.robots.implementations.ur5e import get_model
 
@@ -158,76 +161,11 @@ TODO: finish
 
 def heron_approximation():
     # arm + gripper
-    model_arm, collision_model_arm, visual_model_arm, data_arm = get_model()
+    model_arm, collision_model_arm, visual_model_arm, _ = get_model()
 
     # mobile base as planar joint (there's probably a better
     # option but whatever right now)
-    model_mobile_base = pin.Model()
-    model_mobile_base.name = "mobile_base"
-    geom_model_mobile_base = pin.GeometryModel()
-    joint_name = "mobile_base_planar_joint"
-    parent_id = 0
-    # TEST
-    joint_placement = pin.SE3.Identity()
-    # joint_placement.rotation = pin.rpy.rpyToMatrix(0, -np.pi/2, 0)
-    # joint_placement.translation[2] = 0.2
-    # TODO TODO TODO TODO TODO TODO TODO TODO
-    # TODO: heron is actually a differential drive,
-    # meaning that it is not a planar joint.
-    # we could put in a prismatic + revolute joint
-    # as the base (both joints being at same position),
-    # and that should work for our purposes.
-    # this makes sense for initial testing
-    # because mobile yumi's base is a planar joint
-    MOBILE_BASE_JOINT_ID = model_mobile_base.addJoint(
-        parent_id, pin.JointModelPlanar(), joint_placement.copy(), joint_name
-    )
-    # we should immediately set velocity limits.
-    # there are no position limit by default and that is what we want.
-    # TODO: put in heron's values
-    # TODO: make these parameters the same as in mpc_params in the planner
-    model_mobile_base.velocityLimit[0] = 2
-    # TODO: PUT THE CONSTRAINTS BACK!!!!!!!!!!!!!!!
-    model_mobile_base.velocityLimit[1] = 0
-    # model_mobile_base.velocityLimit[1] = 2
-    model_mobile_base.velocityLimit[2] = 2
-    # TODO: i have literally no idea what reasonable numbers are here
-    model_mobile_base.effortLimit[0] = 200
-    # TODO: PUT THE CONSTRAINTS BACK!!!!!!!!!!!!!!!
-    model_mobile_base.effortLimit[1] = 0
-    # model_mobile_base.effortLimit[1] = 2
-    model_mobile_base.effortLimit[2] = 200
-    # print("OBJECT_JOINT_ID",OBJECT_JOINT_ID)
-    # body_inertia = pin.Inertia.FromBox(args.box_mass, box_dimensions[0],
-    #        box_dimensions[1], box_dimensions[2])
-
-    # pretty much random numbers
-    # TODO: find heron (mir) numbers
-    body_inertia = pin.Inertia.FromBox(30, 0.5, 0.3, 0.4)
-    # maybe change placement to sth else depending on where its grasped
-    model_mobile_base.appendBodyToJoint(
-        MOBILE_BASE_JOINT_ID, body_inertia, pin.SE3.Identity()
-    )
-    box_shape = fcl.Box(0.5, 0.3, 0.4)
-    body_placement = pin.SE3.Identity()
-    geometry_mobile_base = pin.GeometryObject(
-        "box_shape", MOBILE_BASE_JOINT_ID, box_shape, body_placement.copy()
-    )
-
-    geometry_mobile_base.meshColor = np.array([1.0, 0.1, 0.1, 1.0])
-    geom_model_mobile_base.addGeometryObject(geometry_mobile_base)
-
-    # have to add the frame manually
-    model_mobile_base.addFrame(
-        pin.Frame(
-            "mobile_base",
-            MOBILE_BASE_JOINT_ID,
-            0,
-            joint_placement.copy(),
-            pin.FrameType.JOINT,
-        )
-    )
-
+    model_mobile_base, geom_model_mobile_base = get_mobile_base_model(True)
     # frame-index should be 1
     model, visual_model = pin.appendModel(
         model_mobile_base,
diff --git a/python/smc/robots/interfaces/mobile_base_interface.py b/python/smc/robots/interfaces/mobile_base_interface.py
index 2593f6a..1ad00cb 100644
--- a/python/smc/robots/interfaces/mobile_base_interface.py
+++ b/python/smc/robots/interfaces/mobile_base_interface.py
@@ -3,6 +3,7 @@ from smc.robots.robotmanager_abstract import AbstractRobotManager
 import numpy as np
 import pinocchio as pin
 from argparse import Namespace
+import hppfcl as fcl
 
 
 class MobileBaseInterface(AbstractRobotManager):
@@ -70,3 +71,92 @@ class MobileBaseInterface(AbstractRobotManager):
     #    self._updateQ()
     #    self._updateV()
     #    self.forwardKinematics()
+
+
+def get_mobile_base_model(underactuated: bool) -> tuple[pin.Model, pin.GeometryModel]:
+
+    # mobile base as planar joint (there's probably a better
+    # option but whatever right now)
+    model_mobile_base = pin.Model()
+    model_mobile_base.name = "mobile_base"
+    geom_model_mobile_base = pin.GeometryModel()
+    joint_name = "mobile_base_planar_joint"
+    parent_id = 0
+    # TEST
+    joint_placement = pin.SE3.Identity()
+    # joint_placement.rotation = pin.rpy.rpyToMatrix(0, -np.pi/2, 0)
+    # joint_placement.translation[2] = 0.2
+    # TODO TODO TODO TODO TODO TODO TODO TODO
+    # TODO: heron is actually a differential drive,
+    # meaning that it is not a planar joint.
+    # we could put in a prismatic + revolute joint
+    # as the base (both joints being at same position),
+    # and that should work for our purposes.
+    # this makes sense for initial testing
+    # because mobile yumi's base is a planar joint
+    MOBILE_BASE_JOINT_ID = model_mobile_base.addJoint(
+        parent_id, pin.JointModelPlanar(), joint_placement.copy(), joint_name
+    )
+    # we should immediately set velocity limits.
+    # there are no position limit by default and that is what we want.
+    # TODO: put in heron's values
+    # TODO: make these parameters the same as in mpc_params in the planner
+    if underactuated:
+        model_mobile_base.velocityLimit[0] = 2
+        # TODO: PUT THE CONSTRAINTS BACK!!!!!!!!!!!!!!!
+        model_mobile_base.velocityLimit[1] = 0
+        # model_mobile_base.velocityLimit[1] = 2
+        model_mobile_base.velocityLimit[2] = 2
+        # TODO: i have literally no idea what reasonable numbers are here
+        model_mobile_base.effortLimit[0] = 200
+        # TODO: PUT THE CONSTRAINTS BACK!!!!!!!!!!!!!!!
+        model_mobile_base.effortLimit[1] = 0
+        # model_mobile_base.effortLimit[1] = 2
+        model_mobile_base.effortLimit[2] = 200
+        # print("OBJECT_JOINT_ID",OBJECT_JOINT_ID)
+        # body_inertia = pin.Inertia.FromBox(args.box_mass, box_dimensions[0],
+        #        box_dimensions[1], box_dimensions[2])
+    else:
+        model_mobile_base.velocityLimit[0] = 2
+        # TODO: PUT THE CONSTRAINTS BACK!!!!!!!!!!!!!!!
+        model_mobile_base.velocityLimit[1] = 2
+        # model_mobile_base.velocityLimit[1] = 2
+        model_mobile_base.velocityLimit[2] = 2
+        # TODO: i have literally no idea what reasonable numbers are here
+        model_mobile_base.effortLimit[0] = 200
+        # TODO: PUT THE CONSTRAINTS BACK!!!!!!!!!!!!!!!
+        model_mobile_base.effortLimit[1] = 200
+        # model_mobile_base.effortLimit[1] = 2
+        model_mobile_base.effortLimit[2] = 200
+        # print("OBJECT_JOINT_ID",OBJECT_JOINT_ID)
+        # body_inertia = pin.Inertia.FromBox(args.box_mass, box_dimensions[0],
+        #        box_dimensions[1], box_dimensions[2])
+
+    # pretty much random numbers
+    # TODO: find heron (mir) numbers
+    body_inertia = pin.Inertia.FromBox(30, 0.5, 0.3, 0.4)
+    # maybe change placement to sth else depending on where its grasped
+    model_mobile_base.appendBodyToJoint(
+        MOBILE_BASE_JOINT_ID, body_inertia, pin.SE3.Identity()
+    )
+    box_shape = fcl.Box(0.5, 0.3, 0.4)
+    body_placement = pin.SE3.Identity()
+    geometry_mobile_base = pin.GeometryObject(
+        "box_shape", MOBILE_BASE_JOINT_ID, box_shape, body_placement.copy()
+    )
+
+    geometry_mobile_base.meshColor = np.array([1.0, 0.1, 0.1, 1.0])
+    geom_model_mobile_base.addGeometryObject(geometry_mobile_base)
+
+    # have to add the frame manually
+    model_mobile_base.addFrame(
+        pin.Frame(
+            "mobile_base",
+            MOBILE_BASE_JOINT_ID,
+            0,
+            joint_placement.copy(),
+            pin.FrameType.JOINT,
+        )
+    )
+
+    return model_mobile_base, geom_model_mobile_base
-- 
GitLab