From b621207a3fdfe1e4e5693e5a9f5723e50d9dadb3 Mon Sep 17 00:00:00 2001
From: Anders Blomdell <anders.blomdell@control.lth.se>
Date: Mon, 17 Nov 2014 21:40:53 +0100
Subject: [PATCH] Made sample references a primitive type. C and python passes
 some tests...

---
 compiler/CS_CodeGen.jrag                      |  16 +-
 compiler/C_CodeGen.jrag                       |  60 +++--
 compiler/Java_CodeGen.jrag                    |   8 +-
 compiler/LabCommParser.parser                 |   7 +-
 compiler/LabCommTokens.jrag                   |   1 +
 compiler/Python_CodeGen.jrag                  |   5 +-
 lib/c/labcomm_encoder.c                       |   9 +-
 lib/c/test/test_labcomm_generated_encoding.c  |   2 +-
 lib/csharp/se/lth/control/labcomm/Decoder.cs  |   1 +
 .../se/lth/control/labcomm/DecoderChannel.cs  |   4 +
 lib/csharp/se/lth/control/labcomm/Encoder.cs  |   1 +
 .../se/lth/control/labcomm/EncoderChannel.cs  |   6 +
 lib/java/se/lth/control/labcomm/Decoder.java  |   1 +
 .../lth/control/labcomm/DecoderChannel.java   |   7 +
 lib/java/se/lth/control/labcomm/Encoder.java  |   1 +
 .../lth/control/labcomm/EncoderChannel.java   |   7 +
 lib/python/labcomm/LabComm.py                 | 147 +++++++++---
 lib/python/labcomm/__init__.py                |   5 +-
 test/relay_gen_c.py                           |   4 +-
 test/relay_gen_java.py                        |   1 +
 test/test_encoder_decoder.py                  | 211 +++++++++++-------
 21 files changed, 333 insertions(+), 171 deletions(-)

diff --git a/compiler/CS_CodeGen.jrag b/compiler/CS_CodeGen.jrag
index 0423e06..adcb0c6 100644
--- a/compiler/CS_CodeGen.jrag
+++ b/compiler/CS_CodeGen.jrag
@@ -546,6 +546,7 @@ aspect CS_Class {
       case LABCOMM_FLOAT: { env.print("e.encodeFloat"); } break;
       case LABCOMM_DOUBLE: { env.print("e.encodeDouble"); } break;
       case LABCOMM_STRING: { env.print("e.encodeString"); } break;
+      case LABCOMM_SAMPLE: { env.println("e.encodeSampleRef"); } break;
     }
     env.println("(" + name + ");");
   }
@@ -640,6 +641,11 @@ aspect CS_Class {
       case LABCOMM_FLOAT: { env.println("d.decodeFloat();"); } break;
       case LABCOMM_DOUBLE: { env.println("d.decodeDouble();"); } break;
       case LABCOMM_STRING: { env.println("d.decodeString();"); } break;
+      case LABCOMM_SAMPLE: { env.println("d.decodeSampleRef();"); } break;
+      default: {
+        throw new Error("PrimType.CS_emitDecoder(CS_env env, String name)" + 
+                        " unknown token type");
+      }
     }
   }
 
@@ -720,13 +726,10 @@ aspect CS_Class {
 		    " not declared");
   }
 
-  public void SampleRefType.CS_emitTypePrefix(CS_env env) {
-    env.print("Sample");
-  }
-
   public void PrimType.CS_emitTypePrefix(CS_env env) {
     switch (getToken()) {
       case LABCOMM_STRING: { env.print("String"); } break;
+      case LABCOMM_SAMPLE: { env.print("Sample"); } break;
       default: { env.print(getName()); } break;
     }
   }
@@ -848,10 +851,6 @@ aspect CS_Class {
 		    " not declared");
   }
 
-  public void SampleRefType.CS_emitType(CS_env env) {
-    env.print("Sample");
-  }
-
   public void VoidType.CS_emitType(CS_env env) {
     env.print("void");
   }
@@ -860,6 +859,7 @@ aspect CS_Class {
     switch (getToken()) {
       case LABCOMM_STRING: { env.print("String"); } break;
       case LABCOMM_BOOLEAN: { env.print("bool"); } break;
+      case LABCOMM_SAMPLE: { env.print("Sample"); } break;
       default: { env.print(getName()); } break;
     }
   }
diff --git a/compiler/C_CodeGen.jrag b/compiler/C_CodeGen.jrag
index fcef08d..0beef11 100644
--- a/compiler/C_CodeGen.jrag
+++ b/compiler/C_CodeGen.jrag
@@ -298,10 +298,6 @@ aspect C_Type {
     env.print("char " + name);
   }
 
-  public void SampleRefType.C_emitType(C_env env, String name) {
-    env.print("const struct labcomm_signature *" + name);
-  }
-
   public void PrimType.C_emitType(C_env env, String name) {
     switch (getToken()) {
       case LABCOMM_BOOLEAN: { env.print("uint8_t"); } break;
@@ -312,6 +308,9 @@ aspect C_Type {
       case LABCOMM_FLOAT: { env.print("float"); } break;
       case LABCOMM_DOUBLE: { env.print("double"); } break;
       case LABCOMM_STRING: { env.print("char*"); } break;
+      case LABCOMM_SAMPLE: { 
+        env.print("const struct labcomm_signature *"); 
+      } break;
     }
     env.print(" " + name);
   }
@@ -523,12 +522,17 @@ aspect C_Decoder {
   public void VoidType.C_emitDecoder(C_env env) {
   }
 
-  public void SampleRefType.C_emitDecoder(C_env env) {
-    env.println(env.qualid + " = labcomm_internal_decoder_index_to_signature(" +
-                             "r->decoder, labcomm"+env.verStr+"_read_int(r));");
-  }
   public void PrimType.C_emitDecoder(C_env env) {
-    env.println(env.qualid + " = labcomm"+env.verStr+"_read_" + getName() + "(r);");
+    env.print(env.qualid + " = ");
+    switch (getToken()) {
+      case LABCOMM_SAMPLE: { 
+        env.println("labcomm_internal_decoder_index_to_signature(" +
+                    "r->decoder, labcomm"+env.verStr+"_read_int(r));");
+      } break;
+      default: {
+        env.println("labcomm"+env.verStr+"_read_" + getName() + "(r);");
+      }; break;
+    }
   }
 
   public void UserType.C_emitDecoder(C_env env) {
@@ -604,9 +608,6 @@ aspect C_Decoder {
 		    " not declared");
   }
 
-  public void SampleRefType.C_emitDecoderDeallocation(C_env env) {
-  }
-
   public void PrimType.C_emitDecoderDeallocation(C_env env) {
     if (C_isDynamic()) {
       env.println("labcomm"+env.verStr+"_memory_free(r->memory, 1, " + 
@@ -759,12 +760,6 @@ aspect C_copy {
   public void VoidType.C_emitCopy(C_env env_src, C_env env_dst) {
   }
 
-  public void SampleRefType.C_emitCopy(C_env env_src, C_env env_dst) {
-    env_src.println(env_dst.accessor() + env_dst.qualid + " = " +
-                    env_src.accessor() + env_src.qualid + ";");
-  
-  }
-
   public void PrimType.C_emitCopy(C_env env_src, C_env env_dst) {
     if (C_isDynamic()) {
       env_src.println(String.format(
@@ -908,9 +903,6 @@ aspect C_copy {
   public void VoidType.C_emitCopyDeallocation(C_env env) {
   }
 
-  public void SampleRefType.C_emitCopyDeallocation(C_env env) {
-  }
-
   public void PrimType.C_emitCopyDeallocation(C_env env) {
     if (C_isDynamic()) {
       env.println("labcomm" + env.verStr + "_memory_free(mem, 1, " +
@@ -1065,16 +1057,19 @@ aspect C_Encoder {
     env.println("result = 0;");
   }
 
-  public void SampleRefType.C_emitEncoder(C_env env) {
-    env.println("result = labcomm"+env.verStr+"_write_int(w, " + 
-                "labcomm_internal_encoder_signature_to_index(w->encoder, " +
-                env.qualid + "));");
-    env.println("if (result != 0) { return result; }");
-  }
-
   public void PrimType.C_emitEncoder(C_env env) {
-    env.println("result = labcomm"+env.verStr+"_write_" + getName() + 
-                "(w, " + env.qualid + ");");
+    env.print("result = ");
+    switch (getToken()) {
+      case LABCOMM_SAMPLE: { 
+        env.println("labcomm"+env.verStr+"_write_int(w, " + 
+                    "labcomm_internal_encoder_signature_to_index(w->encoder, " +
+                    env.qualid + "));");
+      } break;
+      default: {
+        env.println("labcomm"+env.verStr+"_write_" + getName() + 
+                    "(w, " + env.qualid + ");");
+      } break;
+    }
     env.println("if (result != 0) { return result; }");
   }
 
@@ -1496,10 +1491,6 @@ aspect C_Sizeof {
     return 0;
   }
 
-  public int SampleRefType.C_fixedSizeof() {
-    return 4;
-  }
-
   public int PrimType.C_fixedSizeof() {
     switch (getToken()) {
       case LABCOMM_BOOLEAN: { return 1; } 
@@ -1509,6 +1500,7 @@ aspect C_Sizeof {
       case LABCOMM_LONG: { return 8; }
       case LABCOMM_FLOAT: { return 4; }
       case LABCOMM_DOUBLE: { return 8; }
+      case LABCOMM_SAMPLE: { return 4; }
       default: { 
     throw new Error(this.getClass().getName() + 
             ".C_fixedSizeof()" + 
diff --git a/compiler/Java_CodeGen.jrag b/compiler/Java_CodeGen.jrag
index cfa38cc..db2d562 100644
--- a/compiler/Java_CodeGen.jrag
+++ b/compiler/Java_CodeGen.jrag
@@ -645,6 +645,7 @@ aspect Java_Class {
       case LABCOMM_FLOAT: { env.print("e.encodeFloat"); } break;
       case LABCOMM_DOUBLE: { env.print("e.encodeDouble"); } break;
       case LABCOMM_STRING: { env.print("e.encodeString"); } break;
+      case LABCOMM_SAMPLE: { env.print("e.encodeSampleRef"); } break;
     }
     env.println("(" + name + ");");
   }
@@ -735,6 +736,7 @@ aspect Java_Class {
       case LABCOMM_FLOAT: { env.println("d.decodeFloat();"); } break;
       case LABCOMM_DOUBLE: { env.println("d.decodeDouble();"); } break;
       case LABCOMM_STRING: { env.println("d.decodeString();"); } break;
+      case LABCOMM_SAMPLE: { env.println("d.decodeSampleRef();"); } break;
     }
   }
 
@@ -819,6 +821,7 @@ aspect Java_Class {
   public void PrimType.Java_emitTypePrefix(Java_env env) {
     switch (getToken()) {
       case LABCOMM_STRING: { env.print("String"); } break;
+      case LABCOMM_SAMPLE: { env.print("Sample"); } break;
       default: { env.print(getName()); } break;
     }
   }
@@ -938,10 +941,6 @@ aspect Java_Class {
 		    " not declared");
   }
 
-  public void SampleRefType.Java_emitType(Java_env env) {
-    env.print("Sample");
-  }
-
   public void VoidType.Java_emitType(Java_env env) {
     env.print("void");
   }
@@ -949,6 +948,7 @@ aspect Java_Class {
   public void PrimType.Java_emitType(Java_env env) {
     switch (getToken()) {
       case LABCOMM_STRING: { env.print("String"); } break;
+      case LABCOMM_SAMPLE: { env.print("Sample"); } break;
       default: { env.print(getName()); } break;
     }
   }
diff --git a/compiler/LabCommParser.parser b/compiler/LabCommParser.parser
index e885fb2..f326d63 100644
--- a/compiler/LabCommParser.parser
+++ b/compiler/LabCommParser.parser
@@ -77,7 +77,6 @@ Type type =
   | user_type.u                     {: return u; :}
   | struct_type.s                   {: return s; :}
   | void_type.v                     {: return v; :}
-  | sample_ref_type.s               {: return s; :}
   ;
 
 PrimType prim_type =
@@ -97,6 +96,8 @@ PrimType prim_type =
       {: return new PrimType(DOUBLE, ASTNode.LABCOMM_DOUBLE); :}
   | STRING
       {: return new PrimType(STRING, ASTNode.LABCOMM_STRING); :}
+  | SAMPLE
+      {: return new PrimType(SAMPLE, ASTNode.LABCOMM_SAMPLE); :}
   ;
 
 UserType user_type =
@@ -111,10 +112,6 @@ VoidType void_type =
     VOID {: return new VoidType(); :} 
 ;
 
-SampleRefType sample_ref_type = 
-    SAMPLE {: return new SampleRefType(); :} 
-;
-
 List dim_list =
     dim.d                           {: return new List().add(d); :}
   | dim_list.l  dim.d               {: return l.add(d); :}
diff --git a/compiler/LabCommTokens.jrag b/compiler/LabCommTokens.jrag
index 741ef60..557714b 100644
--- a/compiler/LabCommTokens.jrag
+++ b/compiler/LabCommTokens.jrag
@@ -16,5 +16,6 @@ aspect LabCommTokens {
   public static final int ASTNode.LABCOMM_FLOAT =      0x25;
   public static final int ASTNode.LABCOMM_DOUBLE =     0x26;
   public static final int ASTNode.LABCOMM_STRING =     0x27;
+  public static final int ASTNode.LABCOMM_SAMPLE =     0x28;
 
 }
diff --git a/compiler/Python_CodeGen.jrag b/compiler/Python_CodeGen.jrag
index 0063b30..17203b8 100644
--- a/compiler/Python_CodeGen.jrag
+++ b/compiler/Python_CodeGen.jrag
@@ -152,10 +152,6 @@ aspect PythonTypes {
                     " not declared");
   }
 
-  public void SampleRefType.Python_genSignature(Python_env env) {
-    env.print("labcomm.SAMPLE_REF()");
-  }
-
   public void PrimType.Python_genSignature(Python_env env) {
     switch (getToken()) {
       case LABCOMM_BOOLEAN: { env.print("labcomm.BOOLEAN()"); } break;
@@ -166,6 +162,7 @@ aspect PythonTypes {
       case LABCOMM_FLOAT: { env.print("labcomm.FLOAT()"); } break;
       case LABCOMM_DOUBLE: { env.print("labcomm.DOUBLE()"); } break;
       case LABCOMM_STRING: { env.print("labcomm.STRING()"); } break;
+      case LABCOMM_SAMPLE: { env.print("labcomm.SAMPLE()"); } break;
     }
   }
 
diff --git a/lib/c/labcomm_encoder.c b/lib/c/labcomm_encoder.c
index 1318e5a..973ac78 100644
--- a/lib/c/labcomm_encoder.c
+++ b/lib/c/labcomm_encoder.c
@@ -231,9 +231,12 @@ int labcomm_internal_encoder_signature_to_index(
   struct labcomm_encoder *e, const struct labcomm_signature *signature)
 {
   /* writer_lock should be held at this point */
-  int index = labcomm_get_local_index(signature);
-  if (! LABCOMM_SIGNATURE_ARRAY_GET(e->sample_ref, int, index, 0)) {
-    index = 0;
+  int index = 0;
+  if (signature != NULL) {
+    index = labcomm_get_local_index(signature);
+    if (! LABCOMM_SIGNATURE_ARRAY_GET(e->sample_ref, int, index, 0)) {
+      index = 0;
+    }
   }
   return index;
 }
diff --git a/lib/c/test/test_labcomm_generated_encoding.c b/lib/c/test/test_labcomm_generated_encoding.c
index e6fff79..06e972d 100644
--- a/lib/c/test/test_labcomm_generated_encoding.c
+++ b/lib/c/test/test_labcomm_generated_encoding.c
@@ -221,7 +221,7 @@ int main(void)
                                       labcomm_signature_generated_encoding_R);
   labcomm_encoder_sample_ref_register(encoder, 
                                       labcomm_signature_generated_encoding_R);
-  EXPECT({0x03, 0x08, -1, 0x01, 'R', 0x04, 0x10, 0x01, 0x04, 0x03});
+  EXPECT({0x03, 0x08, -1, 0x01, 'R', 0x04, 0x10, 0x01, 0x04, 0x28});
 
   labcomm_encoder_ioctl(encoder, IOCTL_WRITER_RESET);
   // was: labcomm_encode_generated_encoding_V(encoder, &V);
diff --git a/lib/csharp/se/lth/control/labcomm/Decoder.cs b/lib/csharp/se/lth/control/labcomm/Decoder.cs
index 6f2086b..9bdbaa3 100644
--- a/lib/csharp/se/lth/control/labcomm/Decoder.cs
+++ b/lib/csharp/se/lth/control/labcomm/Decoder.cs
@@ -16,6 +16,7 @@ namespace se.lth.control.labcomm {
     double decodeDouble();
     String decodeString();
     int decodePacked32();
+    Sample decodeSampleRef();
 
   }
 
diff --git a/lib/csharp/se/lth/control/labcomm/DecoderChannel.cs b/lib/csharp/se/lth/control/labcomm/DecoderChannel.cs
index 60e94b1..26a3e65 100644
--- a/lib/csharp/se/lth/control/labcomm/DecoderChannel.cs
+++ b/lib/csharp/se/lth/control/labcomm/DecoderChannel.cs
@@ -146,5 +146,9 @@ namespace se.lth.control.labcomm {
 
       return (int) (res & 0xffffffff);
     }
+
+    public Sample decodeSampleRef() {
+      return null;
+    }
   }
 } 
diff --git a/lib/csharp/se/lth/control/labcomm/Encoder.cs b/lib/csharp/se/lth/control/labcomm/Encoder.cs
index f33af17..0badded 100644
--- a/lib/csharp/se/lth/control/labcomm/Encoder.cs
+++ b/lib/csharp/se/lth/control/labcomm/Encoder.cs
@@ -17,6 +17,7 @@ namespace se.lth.control.labcomm {
     void encodeDouble(double value);
     void encodeString(String value);
     void encodePacked32(Int64 value);
+    void encodeSampleRef(Sample value);
     
   }
 
diff --git a/lib/csharp/se/lth/control/labcomm/EncoderChannel.cs b/lib/csharp/se/lth/control/labcomm/EncoderChannel.cs
index 7ecd506..90c0ac8 100644
--- a/lib/csharp/se/lth/control/labcomm/EncoderChannel.cs
+++ b/lib/csharp/se/lth/control/labcomm/EncoderChannel.cs
@@ -117,5 +117,11 @@ namespace se.lth.control.labcomm {
     public void encodePacked32(Int64 value) {
       WritePacked32(bytes, value);
     }
+
+    public void encodeSampleRef(Sample value) {
+      WriteInt(0, 4);
+      throw new Exception("IMPLEMENT");
+    }
+
   }
 }
diff --git a/lib/java/se/lth/control/labcomm/Decoder.java b/lib/java/se/lth/control/labcomm/Decoder.java
index c3320f3..317133d 100644
--- a/lib/java/se/lth/control/labcomm/Decoder.java
+++ b/lib/java/se/lth/control/labcomm/Decoder.java
@@ -15,5 +15,6 @@ public interface Decoder {
   public double decodeDouble() throws IOException;
   public String decodeString() throws IOException;
   public int decodePacked32() throws IOException;
+  public Sample decodeSampleRef() throws IOException;
 
 }
diff --git a/lib/java/se/lth/control/labcomm/DecoderChannel.java b/lib/java/se/lth/control/labcomm/DecoderChannel.java
index c8d6925..276716d 100644
--- a/lib/java/se/lth/control/labcomm/DecoderChannel.java
+++ b/lib/java/se/lth/control/labcomm/DecoderChannel.java
@@ -172,5 +172,12 @@ public class DecoderChannel implements Decoder {
 
     return (int) (res & 0xffffffff);
   }
+
+  public Sample decodeSampleRef() throws IOException {
+    int index = in.readInt();
+    throw new IOException("IMPLEMENT");
+//    return null;
+  }
+    
 }
 
diff --git a/lib/java/se/lth/control/labcomm/Encoder.java b/lib/java/se/lth/control/labcomm/Encoder.java
index d6ef3e5..1d0f452 100644
--- a/lib/java/se/lth/control/labcomm/Encoder.java
+++ b/lib/java/se/lth/control/labcomm/Encoder.java
@@ -16,5 +16,6 @@ public interface Encoder {
   public void encodeDouble(double value) throws IOException;
   public void encodeString(String value) throws IOException;
   public void encodePacked32(long value) throws IOException;
+  public void encodeSampleRef(Sample value) throws IOException;
 
 }
diff --git a/lib/java/se/lth/control/labcomm/EncoderChannel.java b/lib/java/se/lth/control/labcomm/EncoderChannel.java
index 08ef0fc..91dade4 100644
--- a/lib/java/se/lth/control/labcomm/EncoderChannel.java
+++ b/lib/java/se/lth/control/labcomm/EncoderChannel.java
@@ -131,5 +131,12 @@ public class EncoderChannel implements Encoder {
       encodeByte((byte)(tmp[i] | (i!=0?0x80:0x00)));
     }
   }
+
+  public void encodeSampleRef(Sample value) throws IOException {
+    data.writeInt(0);
+    throw new IOException("IMPLEMENT");
+  }
+    
+
 }
 
diff --git a/lib/python/labcomm/LabComm.py b/lib/python/labcomm/LabComm.py
index dff29ce..f22707b 100644
--- a/lib/python/labcomm/LabComm.py
+++ b/lib/python/labcomm/LabComm.py
@@ -156,8 +156,7 @@ DEFAULT_VERSION = "LabComm2014"
 # Allowed packet tags
 i_VERSION     = 0x01
 i_SAMPLE_DEF  = 0x02
-i_TYPE_DEF    = 0x03
-i_TYPE_BINDING= 0x04
+i_SAMPLE_REF  = 0x03
 i_PRAGMA      = 0x3f
 i_USER        = 0x40 # ..0xffffffff
 
@@ -173,6 +172,7 @@ i_LONG    = 0x24
 i_FLOAT   = 0x25
 i_DOUBLE  = 0x26
 i_STRING  = 0x27
+i_SAMPLE  = 0x28
 
 
 # Version testing
@@ -189,7 +189,7 @@ class length_encoder:
         self.data += data
 
     def __enter__(self):
-        return Encoder(self, None)
+        return Encoder(writer=self, version=None, codec=self.encoder)
 
     def __exit__(self, type, value, traceback):
         if usePacketLength(self.version):
@@ -334,18 +334,38 @@ class STRING(primitive):
     def __repr__(self):
         return "labcomm.STRING()"
 
+class SAMPLE(primitive):
+
+    def encode_decl(self, encoder):
+        return encoder.encode_type(i_SAMPLE)
+
+    def encode(self, encoder, value):
+        return encoder.encode_int(encoder.ref_to_index.get(value, 0))
+    
+    def decode(self, decoder, obj=None):
+        return decoder.decode_ref()
+
+    def new_instance(self):
+        return ""
+
+    def __eq__(self, other):
+        return self.__class__ == other.__class__
+
+    def __repr__(self):
+        return "labcomm.SAMPLE()"
+
 #
 # Aggregate types
 #
-class sample(object):
-    def __init__(self, name, decl):
+class sample_def_or_ref(object):
+    def __init__(self, name=None, decl=None):
         self.name = name
         self.decl = decl
 
     def encode_decl(self, encoder):
-        encoder.encode_type(i_SAMPLE_DEF)
+        encoder.encode_type(self.type_index)
         with length_encoder(encoder) as e1:
-            e1.encode_type(encoder.decl_to_index[self])
+            e1.encode_type(self.get_index(encoder))
             e1.encode_string(self.name)
             with length_encoder(e1) as e2:
                 self.decl.encode_decl(e2)
@@ -360,8 +380,8 @@ class sample(object):
             length = decoder.decode_packed32()
         decl = decoder.decode_decl()
         result = self.__class__.__new__(self.__class__)
-        result.__init__(name, decl)
-        decoder.add_decl(result, index)
+        result.__init__(name=name, decl=decl)
+        self.add_index(decoder, index, result)
         return result
 
     def decode(self, decoder, obj=None):
@@ -372,15 +392,58 @@ class sample(object):
     def new_instance(self):
         return self.decl.new_instance()
 
+    def __eq__(self, other):
+        return (type(self) == type(other) and 
+                self.name == other.name and
+                self.decl == other.decl)
+        
+    def __ne__(self, other):
+        return not self == other
+
     def __repr__(self):
-        return "sample('%s', %s)" % (self.name, self.decl)
+        return "%s('%s', %s)" % (self.type_name, self.name, self.decl)
+
+class sample_def(sample_def_or_ref):
+    type_index = i_SAMPLE_DEF
+    type_name = 'sample'
+
+    def get_index(self, encoder):
+        return encoder.decl_to_index[self]
+
+    def add_index(self, decoder, index, decl):
+        decoder.add_decl(decl, index)
 
+class sample_ref(sample_def_or_ref):
+    type_index = i_SAMPLE_REF
+    type_name = 'sample_ref'
+    
+    def __init__(self, name=None, decl=None, sample=None):
+        self.name = name
+        self.decl = decl
+        if sample == None and name != None and decl != None:
+            self.sample = sample_def(name, decl)
+        else:
+            self.sample = sample
+
+    def get_index(self, encoder):
+        return encoder.ref_to_index[self.sample]
+
+    def add_index(self, decoder, index, decl):
+        decoder.add_ref(decl, index)
 
 class array(object):
     def __init__(self, indices, decl):
         self.indices = indices
         self.decl = decl
         
+    def __eq__(self, other):
+        return (type(self) == type(other) and 
+                self.indices == other.indices and
+                self.decl == other.decl)
+        
+    def __ne__(self, other):
+        return not self == other
+
     def encode_decl(self, encoder):
         encoder.encode_type(i_ARRAY)
         encoder.encode_packed32(len(self.indices))
@@ -545,7 +608,8 @@ class struct:
         result += "\n])"
         return result
 
-SAMPLE_DEF = sample(None, None)
+SAMPLE_DEF = sample_def()
+SAMPLE_REF = sample_ref()
 
 ARRAY = array(None, None)
 STRUCT = struct({})
@@ -564,17 +628,23 @@ class anonymous_object(dict):
             return self[name]
 
 class Codec(object):
-    def __init__(self):
-        self.type_to_name = {}
-        self.name_to_type = {}
-        self.index_to_decl = {}
-        self.decl_to_index = {}
-        self.name_to_decl = {}
-        self.decl_index = i_USER
-        self.predefined_types()
+    def __init__(self, codec=None):
+        self.type_to_name = codec and codec.type_to_name or {}
+        self.name_to_type = codec and codec.name_to_type or {}
+        self.index_to_decl = codec and codec.index_to_decl or {}
+        self.decl_to_index = codec and codec.decl_to_index  or {}
+        self.name_to_decl = codec and codec.name_to_decl  or {}
+        self.index_to_ref = codec and codec.index_to_ref or {}
+        self.ref_to_index = codec and codec.ref_to_index or {}
+        self.name_to_ref = codec and codec.name_to_ref or {}
+        self.decl_index = codec and codec.decl_index or i_USER
+        self.ref_index = codec and codec.ref_index or i_USER
+        if not codec:
+            self.predefined_types()
 
     def predefined_types(self):
         self.add_decl(SAMPLE_DEF, i_SAMPLE_DEF)
+        self.add_decl(SAMPLE_REF, i_SAMPLE_REF)
 
         self.add_decl(ARRAY, i_ARRAY)
         self.add_decl(STRUCT, i_STRUCT)
@@ -587,6 +657,7 @@ class Codec(object):
         self.add_decl(FLOAT(), i_FLOAT)
         self.add_decl(DOUBLE(), i_DOUBLE)
         self.add_decl(STRING(), i_STRING)
+        self.add_decl(SAMPLE(), i_SAMPLE)
         
     def add_decl(self, decl, index=0):
         if index == 0:
@@ -599,6 +670,17 @@ class Codec(object):
         except:
             pass
         
+    def add_ref(self, ref, index=0):
+        if index == 0:
+            index = self.ref_index
+            self.ref_index += 1
+        self.index_to_ref[index] = ref.sample
+        self.ref_to_index[ref.sample] = index
+        try:
+            self.name_to_ref[ref.sample.name] = ref.sample
+        except:
+            pass
+
     def add_binding(self, name, decl):
         self.type_to_name[decl] = name
         self.name_to_type[name] = decl
@@ -615,8 +697,8 @@ class Codec(object):
         
 
 class Encoder(Codec):
-    def __init__(self, writer, version=DEFAULT_VERSION):
-        super(Encoder, self).__init__()
+    def __init__(self, writer, version=DEFAULT_VERSION, codec=None):
+        super(Encoder, self).__init__(codec)
         self.writer = writer
         self.version = version
         if self.version in [ "LabComm2014" ]:
@@ -637,6 +719,13 @@ class Encoder(Codec):
             decl.encode_decl(self)
             self.writer.mark()
  
+    def add_ref(self, decl, index=0):
+        ref = sample_ref(name=decl.name, decl=decl.decl, sample=decl)
+        super(Encoder, self).add_ref(ref, index)
+        if index == 0:
+            ref.encode_decl(self)
+            self.writer.mark()
+ 
     def encode(self, object, decl=None):
         if decl == None:
             name = self.type_to_name[object.__class__]
@@ -745,12 +834,11 @@ class Decoder(Codec):
         if index == i_SAMPLE_DEF:
             decl = self.index_to_decl[index].decode_decl(self)
             value = None
-        elif index == i_TYPE_DEF:
-            print "Got type_def, skipping %d bytes" % length
-            self.skip(length)
-        elif index == i_TYPE_BINDING:
-            print "Got type_binding, skipping %d bytes" % length
-            self.skip(length)
+        elif index == i_SAMPLE_REF:
+            decl = self.index_to_decl[index].decode_decl(self)
+            value = None
+        elif index < i_USER:
+            raise exception("Invalid type index %d" % index)
         else:
             decl = self.index_to_decl[index]
             value = decl.decode(self)
@@ -809,9 +897,12 @@ class Decoder(Codec):
         return self.unpack("!d")
     
     def decode_string(self):
-        length =  self.decode_packed32()
+        length = self.decode_packed32()
         return self.unpack("!%ds" % length).decode("utf8")
 
+    def decode_ref(self):
+        index = self.decode_int()
+        return self.index_to_ref.get(index, None)
 
 class signature_reader:
     def __init__(self, signature):
diff --git a/lib/python/labcomm/__init__.py b/lib/python/labcomm/__init__.py
index 9eee0b7..0c95e1d 100644
--- a/lib/python/labcomm/__init__.py
+++ b/lib/python/labcomm/__init__.py
@@ -6,7 +6,9 @@ from labcomm.StreamWriter import StreamWriter
 Decoder = labcomm.LabComm.Decoder
 Encoder = labcomm.LabComm.Encoder
 
-sample = labcomm.LabComm.sample
+sample = labcomm.LabComm.sample_def
+sample_def = labcomm.LabComm.sample_def
+sample_ref = labcomm.LabComm.sample_ref
 
 array = labcomm.LabComm.array
 struct = labcomm.LabComm.struct
@@ -20,6 +22,7 @@ LONG = labcomm.LabComm.LONG
 FLOAT = labcomm.LabComm.FLOAT
 DOUBLE = labcomm.LabComm.DOUBLE
 STRING = labcomm.LabComm.STRING
+SAMPLE = labcomm.LabComm.SAMPLE
 
 decl_from_signature = labcomm.LabComm.decl_from_signature
 
diff --git a/test/relay_gen_c.py b/test/relay_gen_c.py
index 837cf72..dd3f894 100755
--- a/test/relay_gen_c.py
+++ b/test/relay_gen_c.py
@@ -74,8 +74,10 @@ if __name__ == '__main__':
     for func,arg,stype in sample:
         result.extend(split_match('^[^|]*\|(.*)$', """
           |  labcomm_encoder_register_%(func)s(e);
+          |  labcomm_encoder_sample_ref_register(e, labcomm_signature_%(func)s);
           |  labcomm_decoder_register_%(func)s(d, handle_%(func)s, e);
-        """ % { 'func': func, 'arg': arg }))
+          |  labcomm_decoder_sample_ref_register(d, labcomm_signature_%(func)s);
+       """ % { 'func': func, 'arg': arg }))
     result.extend(split_match('^[^|]*\|(.*)$', """
       |  labcomm_decoder_run(d);
       |  return 0;
diff --git a/test/relay_gen_java.py b/test/relay_gen_java.py
index 40f584e..757dc01 100755
--- a/test/relay_gen_java.py
+++ b/test/relay_gen_java.py
@@ -28,6 +28,7 @@ if __name__ == '__main__':
       |import java.io.IOException;
       |import se.lth.control.labcomm.DecoderChannel;
       |import se.lth.control.labcomm.EncoderChannel;
+      |import se.lth.control.labcomm.Sample;
       |
       |public class java_relay implements
     """))
diff --git a/test/test_encoder_decoder.py b/test/test_encoder_decoder.py
index f845bd6..7f0fb8b 100755
--- a/test/test_encoder_decoder.py
+++ b/test/test_encoder_decoder.py
@@ -12,77 +12,6 @@ import subprocess
 import sys
 import threading
 
-def generate(decl):
-    if decl.__class__ == labcomm.sample:
-        result = []
-        for values in generate(decl.decl):
-            result.append((decl, values))
-        return result
-
-    elif decl.__class__ == labcomm.struct:
-        result = []
-        if len(decl.field) == 0:
-            result.append({})
-        else:
-            values1 = generate(decl.field[0][1])
-            values2 = generate(labcomm.struct(decl.field[1:]))
-            for v1 in values1:
-                for v2 in values2:
-                    v = dict(v2)
-                    v[decl.field[0][0]] = v1
-                    result.append(v)
-        return result
-    
-    elif decl.__class__ == labcomm.array:
-        if len(decl.indices) == 1:
-            values = generate(decl.decl)
-            if decl.indices[0] == 0:
-                lengths = [0, 1, 2]
-            else:
-                lengths = [ decl.indices[0] ]
-        else:
-            values = generate(labcomm.array(decl.indices[1:],decl.decl))
-            if decl.indices[0] == 0:
-                lengths = [1, 2]
-            else:
-                lengths = [ decl.indices[0] ]
-        result = []
-        for v in values:
-            for i in lengths:
-                element = []
-                for j in range(i):
-                    element.append(v)
-                result.append(element)
-        return result
-
-    elif decl.__class__ == labcomm.BOOLEAN:
-        return [False, True]
-
-    elif decl.__class__ == labcomm.BYTE:
-        return [0, 127, 128, 255]
-
-    elif decl.__class__ == labcomm.SHORT:
-        return [-32768, 0, 32767]
-
-    elif decl.__class__ == labcomm.INTEGER:
-        return [-2147483648, 0, 2147483647]
-
-    elif decl.__class__ == labcomm.LONG:
-        return [-9223372036854775808, 0, 9223372036854775807]
-
-    elif decl.__class__ == labcomm.FLOAT:
-        def tofloat(v):
-            return struct.unpack('f', struct.pack('f', v))[0]
-        return [tofloat(-math.pi), 0.0, tofloat(math.pi)]
-
-    elif decl.__class__ == labcomm.DOUBLE:
-        return [-math.pi, 0.0, math.pi]
-
-    elif decl.__class__ == labcomm.STRING:
-        return ['string', u'sträng' ]
-
-    print>>sys.stderr, decl
-    raise Exception("unhandled decl %s" % decl.__class__)
 
 def labcomm_compile(lc, name, args):
     for lang in [ 'c', 'csharp', 'java', 'python']:
@@ -117,30 +46,146 @@ class Test:
         self.signatures = signatures
         pass
 
+    def generate(self, decl):
+        if decl.__class__ == labcomm.sample:
+            result = []
+            for values in self.generate(decl.decl):
+                result.append((decl, values))
+            return result
+    
+        elif decl.__class__ == labcomm.struct:
+            result = []
+            if len(decl.field) == 0:
+                result.append({})
+            else:
+                values1 = self.generate(decl.field[0][1])
+                values2 = self.generate(labcomm.struct(decl.field[1:]))
+                for v1 in values1:
+                    for v2 in values2:
+                        v = dict(v2)
+                        v[decl.field[0][0]] = v1
+                        result.append(v)
+            return result
+        
+        elif decl.__class__ == labcomm.array:
+            if len(decl.indices) == 1:
+                values = self.generate(decl.decl)
+                if decl.indices[0] == 0:
+                    lengths = [0, 1, 2]
+                else:
+                    lengths = [ decl.indices[0] ]
+            else:
+                values = self.generate(labcomm.array(decl.indices[1:],
+                                                     decl.decl))
+                if decl.indices[0] == 0:
+                    lengths = [1, 2]
+                else:
+                    lengths = [ decl.indices[0] ]
+            result = []
+            for v in values:
+                for i in lengths:
+                    element = []
+                    for j in range(i):
+                        element.append(v)
+                    result.append(element)
+            return result
+    
+        elif decl.__class__ == labcomm.BOOLEAN:
+            return [False, True]
+    
+        elif decl.__class__ == labcomm.BYTE:
+            return [0, 127, 128, 255]
+    
+        elif decl.__class__ == labcomm.SHORT:
+            return [-32768, 0, 32767]
+    
+        elif decl.__class__ == labcomm.INTEGER:
+            return [-2147483648, 0, 2147483647]
+    
+        elif decl.__class__ == labcomm.LONG:
+            return [-9223372036854775808, 0, 9223372036854775807]
+    
+        elif decl.__class__ == labcomm.FLOAT:
+            def tofloat(v):
+                return struct.unpack('f', struct.pack('f', v))[0]
+            return [tofloat(-math.pi), 0.0, tofloat(math.pi)]
+    
+        elif decl.__class__ == labcomm.DOUBLE:
+            return [-math.pi, 0.0, math.pi]
+    
+        elif decl.__class__ == labcomm.STRING:
+            return ['string', u'sträng' ]
+    
+        elif decl.__class__ == labcomm.SAMPLE:
+            return [ s for n,s in self.signatures ]
+    
+        print>>sys.stderr, decl
+        raise Exception("unhandled decl %s" % decl.__class__)
+
+    def uses_refs(self, decls):
+        for decl in decls:
+            if decl.__class__ == labcomm.sample:
+                if self.uses_refs([ decl.decl ]):
+                    return True
+    
+            elif decl.__class__ == labcomm.struct:
+                if self.uses_refs([ d for n,d in decl.field ]):
+                    return True
+        
+            elif decl.__class__ == labcomm.array:
+                if self.uses_refs([ decl.decl ]):
+                    return True
+
+            elif decl.__class__ == labcomm.SAMPLE:
+                return True
+
+        return False
+        
+
     def run(self):
         print>>sys.stderr, 'Testing', self.program
         p = subprocess.Popen(self.program, 
                              stdin=subprocess.PIPE,
-                             stdout=subprocess.PIPE)
+                             stdout=subprocess.PIPE,
+                             stderr=sys.stderr)
         self.expected = None
         self.failed = False
-        self.next = threading.Semaphore(0)
+        self.next = threading.Condition()
         decoder = threading.Thread(target=self.decode, args=(p.stdout,))
         decoder.start()
         encoder = labcomm.Encoder(labcomm.StreamWriter(p.stdin))
         for name,signature in self.signatures:
             encoder.add_decl(signature)
             pass
+        if self.uses_refs([ s for n,s in self.signatures ]):
+            for name,signature in self.signatures:
+                encoder.add_ref(signature)
         for name,signature in self.signatures:
             print>>sys.stderr, "Checking", name,
-            for decl,value in generate(signature):
-                sys.stdout.write('.')
+            for decl,value in self.generate(signature):
+                sys.stderr.write('.')
                 #print name,decl,value,value.__class__
-                self.expected = value
-                encoder.encode(value, decl)
                 self.next.acquire()
-                if self.failed:
-                    p.terminate()
+                self.received_value = None
+                self.received_decl = None
+                encoder.encode(value, decl)
+                self.next.wait(2)
+                self.next.release()
+                if p.poll() != None:
+                    print>>sys.stderr, "Failed with:", p.poll()
+                    self.failed = True
+                elif value != self.received_value:
+                    print>>sys.stderr, "Coding error"
+                    print>>sys.stderr,value == self.received_value
+                    print>>sys.stderr, "Got:     ", self.received_value 
+                    print>>sys.stderr, "         ", self.received_decl 
+                    print>>sys.stderr, "Expected:", value
+                    print>>sys.stderr, "         ", decl
+                    self.failed = True
+                    
+                if self.failed: 
+                    if p.poll() == None:
+                        p.terminate()
                     exit(1)
                 pass
             print>>sys.stderr
@@ -157,9 +202,11 @@ class Test:
             while True:
                 value,decl = decoder.decode()
                 if value != None:
-                    if value != self.expected:
-                        print>>sys.stderr, "Coding error", value, self.expected, decl
-                        self.failed = True
+                    self.next.acquire()
+                    self.received_value = value
+                    self.received_decl = decl
+                    self.expected = None
+                    self.next.notify_all()
                     self.next.release()
                 pass
             pass
-- 
GitLab