bin: Fixed float decoding bug.
[coe.git] / coe / bin.py
index 0aa6e5a..e9efc80 100644 (file)
@@ -1,3 +1,4 @@
+import io, math
 from . import data
 
 T_END = 0
@@ -70,6 +71,30 @@ class encoder(object):
         dst.write(text.encode("utf-8"))
         dst.write(b'\0')
 
+    @staticmethod
+    def writefloat(dst, x):
+        if x == 0.0:
+            mnt, exp = (0, 0)
+        elif math.isinf(x):
+            if x > 0:
+                mnt, exp = (0, 2)
+            else:
+                mnt, exp = (0, 3)
+        elif math.isnan(x):
+            mnt, exp = (0, 4)
+        else:
+            mnt, exp = math.frexp(x)
+            mnt *= 2
+            exp -= 1
+            while mnt != int(mnt):
+                mnt *= 2
+            mnt = int(mnt)
+        buf = bytearray()
+        buf.extend(encoder.encint(mnt))
+        buf.extend(encoder.encint(exp))
+        dst.write(encoder.encint(len(buf)))
+        dst.write(buf)
+
     def dumpseq(self, dst, seq):
         for v in seq:
             self.dump(dst, v)
@@ -89,9 +114,9 @@ class encoder(object):
             return
         if datum == None:
             self.writetag(dst, T_NIL, 0, None)
-        elif datum == False:
+        elif datum is False:
             self.writetag(dst, T_NIL, NIL_FALSE, None)
-        elif datum == True:
+        elif datum is True:
             self.writetag(dst, T_NIL, NIL_TRUE, None)
         elif isinstance(datum, int):
             self.writetag(dst, T_INT, 0, None)
@@ -103,6 +128,9 @@ class encoder(object):
             self.writetag(dst, T_BIT, 0, datum)
             dst.write(self.encint(len(datum)))
             dst.write(datum)
+        elif isinstance(datum, float):
+            self.writetag(dst, T_BIT, BIT_BFLOAT, datum)
+            self.writefloat(dst, datum)
         elif isinstance(datum, data.symbol):
             if datum.ns == "":
                 self.writetag(dst, T_STR, STR_SYM, datum)
@@ -152,9 +180,6 @@ class referror(fmterror):
     def __init__(self):
         super().__init__("bad backref")
 
-class namedtype(type):
-    pass
-
 class decoder(object):
     def __init__(self):
         self.reftab = []
@@ -225,6 +250,9 @@ class decoder(object):
                 return buf
             buf[key] = self.loadtagged(fp, tag)
 
+    def makeobjtype(self, nm):
+        return data.namedtype.make(str(nm), (data.obj, object), {}, typename=nm)
+
     def loadobj(self, fp, ref=False):
         if ref:
             refid = len(self.reftab)
@@ -232,8 +260,7 @@ class decoder(object):
         nm = self.load(fp)
         typ = self.namedtypes.get(nm)
         if typ is None:
-            typ = self.namedtypes[nm] = namedtype(str(nm), (data.obj, object), {})
-            typ.typename = nm
+            typ = self.namedtypes[nm] = self.makeobjtype(nm)
         ret = typ()
         if ref:
             self.reftab[refid] = ret
@@ -260,16 +287,34 @@ class decoder(object):
                 return self.reftab[idx]
             return self.addref(self.loadint(fp))
         elif pri == T_STR:
-            ret = self.addref(self.loadstr(fp))
+            ret = self.loadstr(fp)
             if sec == STR_SYM:
-                return data.symbol.get("", ret)
-            return ret
+                return self.addref(data.symbol.get("", ret))
+            return self.addref(ret)
         elif pri == T_BIT:
             ln = self.loadint(fp)
-            ret = self.addref(fp.read(ln))
+            ret = fp.read(ln)
             if len(ret) < ln:
                 raise eoferror()
-            return ret
+            if sec == BIT_BFLOAT:
+                buf = io.BytesIO(ret)
+                mnt = self.loadint(buf)
+                exp = self.loadint(buf)
+                if mnt == 0:
+                    if exp == 0:
+                        ret = 0.0
+                    elif exp == 1:
+                        ret = -0.0
+                    elif exp == 2:
+                        ret = float("inf")
+                    elif exp == 3:
+                        ret = -float("inf")
+                    else:
+                        ret = float("nan")
+                else:
+                    ret = math.ldexp(mnt, exp - (mnt.bit_length() - 1))
+                return self.addref(ret)
+            return self.addref(ret)
         elif pri == T_NIL:
             if sec == NIL_TRUE:
                 return self.addref(True)
@@ -293,4 +338,4 @@ class decoder(object):
         return self.loadtagged(fp, tag)
 
 def load(fp):
-    decoder().load(fp)
+    return decoder().load(fp)