C言語からのHaskell関数の呼び出し。その2。

構造体でのデータの受け渡し。加えて、hsc2hsを使ってCとHaskellの接点を切り離す。

# Toolchain. CC names only the compiler driver; the compile-only flag (-c)
# is passed by the recipe that compiles main.c, so overriding CC from the
# command line does not silently change the compilation mode.
CC = gcc
HC = ghc
HSC2HS = hsc2hs

RM = rm -f
# Name of the linked executable.
TARGET = projectile

C_SRCS  = main.c
# projectile_wrapper.hs is generated from projectile_wrapper.hsc by hsc2hs.
HS_SRCS = linearspace.hs projectile_wrapper.hs projectile.hs

# Object lists derived from the source lists above.
HS_OBJS = $(HS_SRCS:.hs=.o)
OBJS = $(C_SRCS:.c=.o) $(HS_OBJS)

# all and clean are commands, not files.
.PHONY: all clean

# Default goal: build the executable.
all : $(TARGET)

# Link the C and Haskell objects with ghc. -no-hs-main is used because the
# program's main() is provided by main.c.
$(TARGET) : $(OBJS)
	$(HC) --make -v -no-hs-main $^ -o $@ -package Tensor

# Generate the wrapper's Haskell source from the hsc2hs template. The
# template includes projectile.h, so a header change must regenerate it.
projectile_wrapper.hs : projectile_wrapper.hsc projectile.h
	$(HSC2HS) $<

# The stub header has no recipe of its own; it is treated as up to date once
# projectile_wrapper.o has been built.
projectile_wrapper_stub.h: projectile_wrapper.o

# main.c includes both projectile.h and the generated stub header.
main.o: main.c projectile.h projectile_wrapper_stub.h
	$(CC) $(CFLAGS) -I`$(HC) --print-libdir`/include -c $< -o $@

# All Haskell sources are handed to a single ghc invocation.
# NOTE(review): every object lists every source as a prerequisite, so this
# recipe may run more than once under `make -j` -- confirm if parallel
# builds matter.
$(HS_OBJS): $(HS_SRCS)
	$(HC) $^

# Remove everything the build generates, including the hsc2hs output.
clean:
	$(RM) $(TARGET) *.o *~ *.hi *_stub.h projectile_wrapper.hs

以下のhscファイル(projectile_wrapper.hsc)で、Haskell側とC側のデータ構造を相互に変換する。

#include "projectile.h"

{- hsc2hs helper macro named alignment: it expands to the offset of a member of
   type t that follows a single char inside an anonymous struct, printed as an
   unsigned long.
   NOTE(review): presumably newer hsc2hs releases provide an equivalent
   directive built in -- confirm before upgrading the toolchain. -}
#let alignment t = "%lu", (unsigned long)offsetof(struct {char x__; t (y__); }, y__)
module ProjectileForC where

import Foreign.C
import Foreign.Ptr
import Foreign.Storable
import Data.Tensor

import Projectile

{- Export projectileForC with the C calling convention so that C code can call it. -}
foreign export ccall projectileForC :: Ptr ProjectileForC.ProjectileParam -> CDouble -> Ptr Vector3d -> IO CInt

{- Haskell counterpart of Vector3d in projectile.h. -}
data Vector3d = Vector3d
  { x :: CDouble
  , y :: CDouble
  , z :: CDouble
  } deriving (Show, Eq)

{- Haskell counterpart of ProjectileParam in projectile.h. -}
data ProjectileParam = ProjectileParam
  { position :: Vector3d
  , velocity :: Vector3d
  } deriving (Show, Eq)

{- Marshals Vector3d to and from the C struct Vector3d. Size, alignment and
   field offsets are taken from projectile.h by hsc2hs. -}
instance Storable Vector3d where
  sizeOf    _ = #{size Vector3d}
  alignment _ = #{alignment Vector3d}

  -- Read the three fields at their C offsets.
  peek src = do
    cx <- #{peek struct Vector3d, x} src
    cy <- #{peek struct Vector3d, y} src
    cz <- #{peek struct Vector3d, z} src
    return (Vector3d cx cy cz)

  -- Write the three fields at their C offsets.
  poke dst (Vector3d cx cy cz) = do
    #{poke struct Vector3d, x} dst cx
    #{poke struct Vector3d, y} dst cy
    #{poke struct Vector3d, z} dst cz

{- Marshals the C-layout ProjectileParam. Each field is itself a Vector3d and
   is read or written through the Storable instance of Vector3d. -}
instance Storable ProjectileForC.ProjectileParam where
  sizeOf    _ = #{size ProjectileParam}
  alignment _ = #{alignment ProjectileParam}

  peek src = do
    pos <- #{peek struct ProjectileParam, position} src
    vel <- #{peek struct ProjectileParam, velocity} src
    return (ProjectileForC.ProjectileParam pos vel)

  poke dst (ProjectileForC.ProjectileParam pos vel) = do
    #{poke struct ProjectileParam, position} dst pos
    #{poke struct ProjectileParam, velocity} dst vel
    
{- Converts a C-layout vector into a Vector3 of Doubles, component by component. -}
vector3dToVector3 :: Vector3d -> Vector3 Double
vector3dToVector3 (Vector3d cx cy cz) = Vector3 (toDouble cx) (toDouble cy) (toDouble cz)
  where
    toDouble :: CDouble -> Double
    toDouble = realToFrac

{- Converts a Vector3 of Doubles into the C-layout vector, component by component. -}
vector3ToVector3d :: Vector3 Double -> Vector3d
vector3ToVector3d (Vector3 hx hy hz) = Vector3d (toCDouble hx) (toCDouble hy) (toCDouble hz)
  where
    toCDouble :: Double -> CDouble
    toCDouble = realToFrac

{- Translates the C-layout parameter record into the Projectile module's
   record by converting both vectors. -}
projectileParamTranslateFromForC :: ProjectileForC.ProjectileParam -> Projectile.ProjectileParam
projectileParamTranslateFromForC cParam = Projectile.ProjectileParam pos vel
  where
    pos = vector3dToVector3 (ProjectileForC.position cParam)
    vel = vector3dToVector3 (ProjectileForC.velocity cParam)

{- Translates the Projectile module's parameter record into the C-layout
   record by converting both vectors. -}
projectileParamTranslateFromToC :: Projectile.ProjectileParam -> ProjectileForC.ProjectileParam
projectileParamTranslateFromToC hsParam = ProjectileForC.ProjectileParam pos vel
  where
    pos = vector3ToVector3d (Projectile.position hsParam)
    vel = vector3ToVector3d (Projectile.velocity hsParam)

{- Entry point called from C. Reads the state at ptrParam and the additional
   acceleration at ptrAddAccel, advances the state by the time step t with
   Projectile.projectile, and writes the new state back to ptrParam.
   Returns 1 to continue, 0 to finish, and a negative value on error
   (-1 when either pointer is NULL, in which case nothing is written). -}
projectileForC :: Ptr ProjectileForC.ProjectileParam -> CDouble -> Ptr Vector3d -> IO CInt
projectileForC ptrParam t ptrAddAccel
  | ptrParam == nullPtr || ptrAddAccel == nullPtr = return (-1)
  | otherwise = do
      currentParam <- peek ptrParam
      addAccel     <- peek ptrAddAccel
      -- The simulation step is pure, so a plain let binding is enough.
      let (result, newParam) =
            Projectile.projectile
              (projectileParamTranslateFromForC currentParam)
              (realToFrac t)
              (vector3dToVector3 addAccel)
      poke ptrParam $ projectileParamTranslateFromToC newParam
      return $ fromIntegral result

Haskell側の本体(projectile.hs)

{-# LANGUAGE TypeFamilies #-}

module Projectile where

import Data.Int
import Data.Tensor
import Data.LinearSpace

{- State of the projectile: its position and velocity vectors. -}
data ProjectileParam = ProjectileParam
  { position :: Vector3 Double
  , velocity :: Vector3 Double
  } deriving (Show)
{- Make Vector3 an instance of the linear-space classes: component-wise
   addition, negation and the zero vector. -}
instance (Fractional a) => AdditiveGroup (Vector3 a) where
  zero = Vector3 0 0 0
  (Vector3 ax ay az) ^+^ (Vector3 bx by bz) =
    Vector3 (ax + bx) (ay + by) (az + bz)
  negateE (Vector3 cx cy cz) = Vector3 (negate cx) (negate cy) (negate cz)

{- Scalar multiplication for Vector3: the scalar type is the component type
   and every component is scaled. -}
instance (Fractional a) => ScalarMultiplicativeGroup (Vector3 a) where
  type Scalar (Vector3 a) = a
  k *^ (Vector3 vx vy vz) = Vector3 (k * vx) (k * vy) (k * vz)

{- Advances the projectile by one time step dt.
   The velocity is updated first from gravity plus addAccel, and the position
   is then updated with that new velocity. The first component of the result
   is 0 when the new position has a negative z component and 1 otherwise. -}
projectile :: ProjectileParam -> Double -> Vector3 Double -> (Int32, ProjectileParam)
projectile param dt addAccel = (status, ProjectileParam nextPosition nextVelocity)
  where
    gravity      = Vector3 0.0 0.0 (-9.8)
    totalAccel   = gravity ^+^ addAccel
    nextVelocity = velocity param ^+^ (dt *^ totalAccel)
    nextPosition = position param ^+^ (dt *^ nextVelocity)
    status       = exitJudge nextPosition
    exitJudge (Vector3 _ _ height)
      | height < 0 = 0
      | otherwise  = 1

呼び出し側のC(projectile.h と main.c)

/*
 * Plain C data types shared between the C driver and the Haskell wrapper.
 * The guard macro avoids a leading underscore followed by an uppercase
 * letter, because such identifiers are reserved for the implementation.
 */
#ifndef PROJECTILE_H
#define PROJECTILE_H

/* Three-component vector of doubles. */
typedef struct Vector3d {
	double x;
	double y;
	double z;
} Vector3d;

/* State of the projectile: position and velocity. */
typedef struct ProjectileParam {
	Vector3d position;
	Vector3d velocity;
} ProjectileParam;

#endif /* PROJECTILE_H */
#include <stdio.h>
#include <stdlib.h>

#include "projectile.h"
#include "projectile_wrapper_stub.h"

/*
 * Drives the Haskell simulation from C: starts the Haskell runtime, calls
 * projectileForC once per time step (at most 30000 steps) and prints the
 * position after every step. The loop ends when projectileForC reports the
 * ending condition (0) or an error (negative value); in the error case the
 * process exits with EXIT_FAILURE.
 */
int main(int argc, char **argv){
	unsigned int i;
	const double deltaTime = 0.001;
	ProjectileParam param = {
		.position = { .x = 0,
					  .y = 0,
					  .z = 0 },
		.velocity = { .x = 0,
					  .y = 0,
					  .z = 0 },
	};
	/* Initial acceleration (it provides the initial velocity).
	 * Applied for one 0.001 s step it brings the projectile to about
	 * 28 m/s (roughly 100 km/h). */
	Vector3d initAccel = {
		.x = 14*1000,
		.y = 14*1000,
		.z = 19.6*1000
	};
	/* Additional acceleration used for every step after the first one. */
	Vector3d defaultAccel = {
		.x = 0,
		.y = 0,
		.z = 0,
	};
	/* Argument vector handed to the Haskell runtime: program name only. */
	int hsArgc = 1;
	char *hsArgv[] = { argv[0], NULL };
	char **pHsArgv = hsArgv;
	int result = 1;
	int status = EXIT_SUCCESS;

	(void)argc; /* the command line is not forwarded */

	hs_init(&hsArgc, &pHsArgv);

	printf("# start projectile\n");
	printf("# x    y    z\n");
	for(i=0;i<30000;i++){
		/* Only the very first step uses the launch acceleration. */
		Vector3d *accel = (i == 0) ? &initAccel : &defaultAccel;

		result = projectileForC(&param, deltaTime, accel);
		if(result < 0){
			/* Negative values signal an error on the Haskell side. */
			fprintf(stderr, "projectileForC failed (result=%d)\n", result);
			status = EXIT_FAILURE;
			break;
		}
#if 0
		printf("%f,%f,%f,%f,%f,%f,%f\n",
			   deltaTime*i,
			   param.position.x, param.position.y, param.position.z,
			   param.velocity.x, param.velocity.y, param.velocity.z);
#else
		printf("%f  %f  %f\n",
			   param.position.x, param.position.y, param.position.z);
#endif
		if(result == 0){
			printf("# reach ending condition\n");
			break;
		}
	}

	hs_exit();
	return status;
}