snakenet.rb: You can now choose between float and integer weights. Floats seem to give better results, and switching to integer weights didn't improve the speed much.
This commit is contained in:
parent
b644006036
commit
a902addf94
@ -4,6 +4,7 @@ require 'pp'
|
|||||||
require 'thread/pool'
|
require 'thread/pool'
|
||||||
|
|
||||||
GAMES_PER_ROUND = 50
|
GAMES_PER_ROUND = 50
|
||||||
|
FLOAT = true
|
||||||
|
|
||||||
class Game
|
class Game
|
||||||
WIDTH = 16
|
WIDTH = 16
|
||||||
@ -177,10 +178,10 @@ class AI
|
|||||||
@rounds = 1
|
@rounds = 1
|
||||||
@id = rand(0xFFFFFF)
|
@id = rand(0xFFFFFF)
|
||||||
if w==nil
|
if w==nil
|
||||||
@weights = Array.new(network_size()) { rand() * 2.0 - 1.0 }
|
@weights = Array.new(network_size()) { FLOAT ? rand() * 2.0 - 1.0 : rand(256) - 128 }
|
||||||
puts "Initialized with random values: #{@weights}" if @debug
|
puts "Initialized with random values: #{@weights}" if @debug
|
||||||
else
|
else
|
||||||
if w[0].is_a? Integer
|
if w[0].is_a?(Integer) && FLOAT
|
||||||
@weights = w.map{|s| s.to_s(16).rjust(8, "0").split("").each_slice(2).to_a.map(&:join).map{|s| s.to_i(16).chr}.join.unpack("g")}.flatten
|
@weights = w.map{|s| s.to_s(16).rjust(8, "0").split("").each_slice(2).to_a.map(&:join).map{|s| s.to_i(16).chr}.join.unpack("g")}.flatten
|
||||||
else
|
else
|
||||||
@weights = w
|
@weights = w
|
||||||
@ -220,7 +221,7 @@ class AI
|
|||||||
(1...(NETWORK_LAYOUT.count)).each do |i|
|
(1...(NETWORK_LAYOUT.count)).each do |i|
|
||||||
c_in = NETWORK_LAYOUT[i-1]
|
c_in = NETWORK_LAYOUT[i-1]
|
||||||
c_out = NETWORK_LAYOUT[i]
|
c_out = NETWORK_LAYOUT[i]
|
||||||
outputs = Array.new(c_out){0.0}
|
outputs = Array.new(c_out){FLOAT ? 0.0 : 0}
|
||||||
(0...c_out).each do |o|
|
(0...c_out).each do |o|
|
||||||
(0...c_in).each do |i|
|
(0...c_in).each do |i|
|
||||||
outputs[o] += inputs[i] * @weights[x]
|
outputs[o] += inputs[i] * @weights[x]
|
||||||
@ -253,33 +254,53 @@ class AI
|
|||||||
# w[i2] = temp
|
# w[i2] = temp
|
||||||
if action==0 #change single value
|
if action==0 #change single value
|
||||||
i = rand(network_size())
|
i = rand(network_size())
|
||||||
diff = rand() * 0.2 - 0.1
|
diff = FLOAT ? rand() * 0.2 - 0.1 : rand(256) - 128
|
||||||
w2 = w.dup
|
w2 = w.dup
|
||||||
w[i] += diff
|
w[i] += diff
|
||||||
w[i] = 1.0 if w[i]>1.0
|
if FLOAT
|
||||||
w[i] = -1.0 if w[i]<-1.0
|
w[i] = 1.0 if w[i]>1.0
|
||||||
|
w[i] = -1.0 if w[i]<-1.0
|
||||||
|
else
|
||||||
|
w[i] = 127 if w[i]>127
|
||||||
|
w[i] = -128 if w[i]<-128
|
||||||
|
end
|
||||||
w2[i] -= diff
|
w2[i] -= diff
|
||||||
w2[i] = 1.0 if w2[i]>1.0
|
if FLOAT
|
||||||
w2[i] = -1.0 if w2[i]<-1.0
|
w2[i] = 1.0 if w2[i]>1.0
|
||||||
|
w2[i] = -1.0 if w2[i]<-1.0
|
||||||
|
else
|
||||||
|
w2[i] = 127 if w2[i]>127
|
||||||
|
w2[i] = -128 if w2[i]<-128
|
||||||
|
end
|
||||||
return [AI.new(w), AI.new(w2)]
|
return [AI.new(w), AI.new(w2)]
|
||||||
elsif action==1 #invert single value
|
elsif action==1 #invert single value
|
||||||
i = rand(network_size())
|
i = rand(network_size())
|
||||||
w[i] *= -1.0
|
w[i] *= FLOAT ? -1.0 : -1
|
||||||
elsif action==2
|
elsif action==2
|
||||||
(0...network_size()).each do |i|
|
(0...network_size()).each do |i|
|
||||||
w[i] = rand() * 2 - 1.0 if rand(5)==0
|
w[i] = (FLOAT ? rand() * 2 - 1.0 : rand(256) - 128) if rand(5)==0
|
||||||
end
|
end
|
||||||
else #change multiple values
|
else #change multiple values
|
||||||
w2 = w.dup
|
w2 = w.dup
|
||||||
(0...network_size()).each do |i|
|
(0...network_size()).each do |i|
|
||||||
if (rand(5)==0)
|
if (rand(5)==0)
|
||||||
diff = rand() * 0.2 - 0.1
|
diff = FLOAT ? rand() * 0.2 - 0.1 : rand(256) - 128
|
||||||
w[i] += diff
|
w[i] += diff
|
||||||
w[i] = 1.0 if w[i]>1.0
|
if FLOAT
|
||||||
w[i] = -1.0 if w[i]<-1.0
|
w[i] = 1.0 if w[i]>1.0
|
||||||
|
w[i] = -1.0 if w[i]<-1.0
|
||||||
|
else
|
||||||
|
w[i] = 127 if w[i]>127
|
||||||
|
w[i] = -128 if w[i]<-128
|
||||||
|
end
|
||||||
w2[i] -= diff
|
w2[i] -= diff
|
||||||
w2[i] = 1.0 if w2[i]>1.0
|
if FLOAT
|
||||||
w2[i] = -1.0 if w2[i]<-1.0
|
w2[i] = 1.0 if w2[i]>1.0
|
||||||
|
w2[i] = -1.0 if w2[i]<-1.0
|
||||||
|
else
|
||||||
|
w2[i] = 127 if w2[i]>127
|
||||||
|
w2[i] = -128 if w2[i]<-128
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return [AI.new(w), AI.new(w2)]
|
return [AI.new(w), AI.new(w2)]
|
||||||
@ -303,13 +324,17 @@ class AI
|
|||||||
w = @weights.dup
|
w = @weights.dup
|
||||||
w2 = ai.weights
|
w2 = ai.weights
|
||||||
(0...network_size()).each do |i|
|
(0...network_size()).each do |i|
|
||||||
w[i] = (w[i] + w2[i]) / 2.0
|
w[i] = (w[i] + w2[i]) / (FLOAT ? 2.0 : 2)
|
||||||
end
|
end
|
||||||
return AI.new(w)
|
return AI.new(w)
|
||||||
end
|
end
|
||||||
|
|
||||||
def dump
|
def dump
|
||||||
puts "const uint32_t _weights[#{network_size()}] = {#{@weights.map{|x| "0x" + [x].pack('g').split("").map(&:ord).map{|i| i.to_s(16).rjust(2, '0')}.join}.join(", ")}};"
|
if FLOAT
|
||||||
|
puts "const uint32_t _weights[#{network_size()}] = {#{@weights.map{|x| "0x" + [x].pack('g').split("").map(&:ord).map{|i| i.to_s(16).rjust(2, '0')}.join}.join(", ")}};"
|
||||||
|
else
|
||||||
|
puts "const int8_t _weights[#{network_size()}] = {#{@weights.join(", ")}};"
|
||||||
|
end
|
||||||
#puts "Simplified: #{simplified}"
|
#puts "Simplified: #{simplified}"
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user